Spaces
roychao19477 committed
Commit bd9ffb1 · 1 Parent(s): a66fd6a
Add application file
- README.md +6 -7
- app.py +376 -0
- mamba_ssm/.DS_Store +0 -0
- mamba_ssm/__init__.py +5 -0
- mamba_ssm/models/__init__.py +0 -0
- mamba_ssm/models/config_mamba.py +15 -0
- mamba_ssm/models/mixer_seq_simple.py +264 -0
- mamba_ssm/modules/__init__.py +0 -0
- mamba_ssm/modules/mamba_simple.py +353 -0
- mamba_ssm/ops/__init__.py +0 -0
- mamba_ssm/ops/selective_scan_interface.py +357 -0
- mamba_ssm/ops/triton/__init__.py +0 -0
- mamba_ssm/ops/triton/layernorm.py +635 -0
- mamba_ssm/ops/triton/selective_state_update.py +263 -0
- mamba_ssm/utils/__init__.py +0 -0
- mamba_ssm/utils/generation.py +387 -0
- mamba_ssm/utils/hf.py +23 -0
- models/codec_module.py +183 -0
- models/discriminator.py +56 -0
- models/generator.py +72 -0
- models/loss.py +145 -0
- models/lsigmoid.py +66 -0
- models/mamba_block.py +110 -0
- models/pcs400.py +53 -0
- models/stfts.py +73 -0
- recipes/SEMamba_advanced.yaml +66 -0
- requirements.txt +22 -0
- yolov8n-face.pt +3 -0
README.md
CHANGED
@@ -1,13 +1,12 @@
 ---
-title:
-
-colorFrom: green
+title: Dev
+colorFrom: purple
 colorTo: indigo
 sdk: gradio
-sdk_version: 5.
+sdk_version: 5.31.0
 app_file: app.py
 pinned: false
-short_description:
+short_description: Dev
+tags:
+- Useless
 ---
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py
ADDED
@@ -0,0 +1,376 @@
import shlex
import subprocess
import spaces
import torch
import os
import shutil
import glob
import gradio as gr

# install packages for mamba
def install_mamba():
    subprocess.run(shlex.split("pip install https://github.com/state-spaces/mamba/releases/download/v2.2.2/mamba_ssm-2.2.2+cu122torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl"))

def clone_github():
    subprocess.run([
        "git", "clone",
        f"https://RoyChao19477:{os.environ['GITHUB_TOKEN']}@github.com/RoyChao19477/for_HF_AVSEMamba.git",
    ])
    # move all files except README.md
    for item in glob.glob("for_HF_AVSEMamba/*"):
        if os.path.basename(item) != "README.md":
            if os.path.isdir(item):
                shutil.move(item, ".")
            else:
                shutil.move(item, os.path.join(".", os.path.basename(item)))

    #shutil.rmtree("tmp_repo")
    #subprocess.run(["ls"], check=True)

install_mamba()
clone_github()

ABOUT = """
# SEMamba: Speech Enhancement
A Mamba-based model that denoises real-world audio.
Upload or record a noisy clip and click **Enhance** to hear + see its spectrogram.
"""


import torch
import ffmpeg
import torchaudio
import torchaudio.transforms as T
import yaml
import librosa
import librosa.display
import matplotlib
import numpy as np
import soundfile as sf
import matplotlib.pyplot as plt
from models.stfts import mag_phase_stft, mag_phase_istft
from models.generator import SEMamba
from models.pcs400 import cal_pcs
from ultralytics import YOLO
import supervision as sv

import gradio as gr
import cv2
import os
import tempfile
from ultralytics import YOLO
from moviepy import ImageSequenceClip
from moviepy.video import fx as vfx
from scipy.io import wavfile
from avse_code import run_avse

# Load face detector
model = YOLO("yolov8n-face.pt").cuda()  # assumes CUDA available


from decord import VideoReader, cpu
from model import AVSEModule
from config import sampling_rate
import spaces

# Load model once globally
#ckpt_path = "ckpts/ep215_0906.oat.ckpt"
#model = AVSEModule.load_from_checkpoint(ckpt_path)
avse_model = AVSEModule()
#avse_state_dict = torch.load("ckpts/ep215_0906.oat.ckpt")
avse_state_dict = torch.load("ckpts/ep220_0908.oat.ckpt")
avse_model.load_state_dict(avse_state_dict, strict=True)
avse_model.to("cuda")
avse_model.eval()

@spaces.GPU
def run_avse_inference(video_path, audio_path):
    estimated = run_avse(video_path, audio_path)
    # Load audio
    #noisy, _ = sf.read(audio_path, dtype='float32')  # (N, )
    #noisy = torch.tensor(noisy).unsqueeze(0)  # (1, N)
    noisy = wavfile.read(audio_path)[1].astype(np.float32) / (2 ** 15)

    # Norm.
    #noisy = noisy * (0.8 / np.max(np.abs(noisy)))

    # Load grayscale video
    vr = VideoReader(video_path, ctx=cpu(0))
    frames = vr.get_batch(list(range(len(vr)))).asnumpy()
    bg_frames = np.array([
        cv2.cvtColor(frames[i], cv2.COLOR_RGB2GRAY) for i in range(len(frames))
    ]).astype(np.float32)
    bg_frames /= 255.0


    # Combine into input dict (match what model.enhance expects)
    data = {
        "noisy_audio": noisy,
        "video_frames": bg_frames[np.newaxis, ...]
    }

    with torch.no_grad():
        estimated = avse_model.enhance(data).reshape(-1)

    # Save result
    tmp_wav = audio_path.replace(".wav", "_enhanced.wav")
    sf.write(tmp_wav, estimated, samplerate=sampling_rate)

    return tmp_wav


def extract_resampled_audio(video_path, target_sr=16000):
    # Step 1: extract audio via torchaudio
    # (moviepy will still extract it to wav temp file)
    tmp_audio_path = tempfile.mktemp(suffix=".wav")
    subprocess.run(["ffmpeg", "-y", "-i", video_path, "-vn", "-acodec", "pcm_s16le", "-ar", "44100", tmp_audio_path])

    # Step 2: Load and resample
    waveform, sr = torchaudio.load(tmp_audio_path)
    if sr != target_sr:
        resampler = T.Resample(orig_freq=sr, new_freq=target_sr)
        waveform = resampler(waveform)

    # Step 3: Save resampled audio
    resampled_audio_path = tempfile.mktemp(suffix="_16k.wav")
    torchaudio.save(resampled_audio_path, waveform, sample_rate=target_sr)
    return resampled_audio_path

@spaces.GPU
def extract_faces(video_file):
    # Step 0: Check resolution
    cap = cv2.VideoCapture(video_file)
    width = cap.get(cv2.CAP_PROP_FRAME_WIDTH)
    height = cap.get(cv2.CAP_PROP_FRAME_HEIGHT)
    cap.release()

    # Step 1: Downsample if needed
    if width > 1280 or height > 720:
        resized_path = tempfile.mktemp(suffix=".mp4")
        subprocess.run([
            "ffmpeg", "-y", "-i", video_file,
            "-vf", "scale='min(1280,iw)':-2",
            "-c:v", "libx264", "-crf", "28",
            "-preset", "fast", "-an", resized_path
        ])
        video_file = resized_path

    cap = cv2.VideoCapture(video_file)
    fps = cap.get(cv2.CAP_PROP_FPS)
    frames = []

    while True:
        ret, frame = cap.read()
        if not ret:
            break

        # Inference
        results = model(frame, verbose=False)[0]
        for box in results.boxes:
            # version 1
            # x1, y1, x2, y2 = map(int, box.xyxy[0])

            # version 2
            h, w, _ = frame.shape
            x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
            pad_ratio = 0.5  # 50% padding

            dx = (x2 - x1) * pad_ratio
            dy = (y2 - y1) * pad_ratio

            x1 = int(max(0, x1 - dx))
            y1 = int(max(0, y1 - dy))
            x2 = int(min(w, x2 + dx))
            y2 = int(min(h, y2 + dy))
            # Added for v3
            shift_down = int(0.1 * (y2 - y1))
            y1 = int(min(max(0, y1 + shift_down), h))
            y2 = int(min(max(0, y2 + shift_down), h))
            face_crop = frame[y1:y2, x1:x2]
            if face_crop.size != 0:
                resized = cv2.resize(face_crop, (224, 224))
                frames.append(resized)

                #h_crop, w_crop = face_crop.shape[:2]
                #side = min(h_crop, w_crop)
                #start_y = (h_crop - side) // 2
                #start_x = (w_crop - side) // 2
                #square_crop = face_crop[start_y:start_y+side, start_x:start_x+side]
                #resized = cv2.resize(square_crop, (224, 224))
                #frames.append(resized)

            break  # only one face per frame

    cap.release()

    # Save as video
    tmpdir = tempfile.mkdtemp()
    output_path = os.path.join(tmpdir, "face_only_video.mp4")
    #clip = ImageSequenceClip([cv2.cvtColor(f, cv2.COLOR_BGR2RGB) for f in frames], fps=25)
    #clip = ImageSequenceClip([cv2.cvtColor(f, cv2.COLOR_BGR2RGB) for f in frames], fps=fps)
    clip = ImageSequenceClip(
        [cv2.cvtColor(cv2.resize(f, (224, 224)), cv2.COLOR_BGR2RGB) for f in frames],
        fps=fps
    ).fx(vfx.flip_vertical)
    clip.write_videofile(output_path, codec="libx264", audio=False, fps=25)

    # Save audio from original, resampled to 16kHz
    audio_path = os.path.join(tmpdir, "audio_16k.wav")

    # Extract audio using ffmpeg-python (more robust than moviepy)
    ffmpeg.input(video_file).output(
        audio_path,
        ar=16000,      # resample to 16k
        ac=1,          # mono
        format='wav',
        vn=None        # no video
    ).run(overwrite_output=True)




    # ------------------------------- #
    # AVSE models

    enhanced_audio_path = run_avse_inference(output_path, audio_path)


    return output_path, enhanced_audio_path
    #return output_path, audio_path

iface = gr.Interface(
    fn=extract_faces,
    inputs=gr.Video(label="Upload or record your video"),
    outputs=[
        gr.Video(label="Detected Face Only Video"),
        #gr.Audio(label="Extracted Audio (16kHz)", type="filepath"),
        gr.Audio(label="Enhanced Audio", type="filepath")
    ],
    title="Face Detector",
    description="Upload or record a video. We'll crop face regions and return a face-only video and its 16kHz audio."
)

iface.launch()



ckpt = "ckpts/SEMamba_advanced.pth"
cfg_f = "recipes/SEMamba_advanced.yaml"

# load config
with open(cfg_f, 'r') as f:
    cfg = yaml.safe_load(f)


# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = "cuda"
model = SEMamba(cfg).to(device)
#sdict = torch.load(ckpt, map_location=device)
#model.load_state_dict(sdict["generator"])
#model.eval()

@spaces.GPU
def enhance(filepath, model_name):
    # Load model based on selection
    ckpt_path = {
        "VCTK-Demand": "ckpts/SEMamba_advanced.pth",
        "VCTK+DNS": "ckpts/vd.pth"
    }[model_name]

    print("Loading:", ckpt_path)
    model.load_state_dict(torch.load(ckpt_path, map_location=device)["generator"])
    model.eval()
    with torch.no_grad():
        # load & resample
        wav, orig_sr = librosa.load(filepath, sr=None)
        noisy_wav = wav.copy()
        if orig_sr != 16000:
            wav = librosa.resample(wav, orig_sr=orig_sr, target_sr=16000)
        x = torch.from_numpy(wav).float().to(device)
        norm = torch.sqrt(len(x)/torch.sum(x**2))
        #x = (x * norm).unsqueeze(0)
        x = (x * norm)

        # split into 4s segments (64000 samples)
        segment_len = 4 * 16000
        chunks = x.split(segment_len)
        enhanced_chunks = []

        for chunk in chunks:
            if len(chunk) < segment_len:
                #pad = torch.zeros(segment_len - len(chunk), device=chunk.device)
                pad = (torch.randn(segment_len - len(chunk), device=chunk.device) * 1e-4)
                chunk = torch.cat([chunk, pad])
            chunk = chunk.unsqueeze(0)

            amp, pha, _ = mag_phase_stft(chunk, 400, 100, 400, 0.3)
            amp2, pha2, _ = model(amp, pha)
            out = mag_phase_istft(amp2, pha2, 400, 100, 400, 0.3)
            out = (out / norm).squeeze(0)
            enhanced_chunks.append(out)

        out = torch.cat(enhanced_chunks)[:len(x)].cpu().numpy()  # trim padding

        # back to original rate
        if orig_sr != 16000:
            out = librosa.resample(out, orig_sr=16000, target_sr=orig_sr)

        # Normalize
        peak = np.max(np.abs(out))
        if peak > 0.05:
            out = out / peak * 0.85

        # write file
        sf.write("enhanced.wav", out, orig_sr)

        # spectrograms
        fig, axs = plt.subplots(1, 2, figsize=(16, 4))

        # noisy
        D_noisy = librosa.stft(noisy_wav, n_fft=512, hop_length=256)
        S_noisy = librosa.amplitude_to_db(np.abs(D_noisy), ref=np.max)
        librosa.display.specshow(S_noisy, sr=orig_sr, hop_length=256, x_axis="time", y_axis="hz", ax=axs[0], vmax=0)
        axs[0].set_title("Noisy Spectrogram")

        # enhanced
        D_clean = librosa.stft(out, n_fft=512, hop_length=256)
        S_clean = librosa.amplitude_to_db(np.abs(D_clean), ref=np.max)
        librosa.display.specshow(S_clean, sr=orig_sr, hop_length=256, x_axis="time", y_axis="hz", ax=axs[1], vmax=0)
        #librosa.display.specshow(S_clean, sr=16000, hop_length=512, x_axis="time", y_axis="hz", ax=axs[1], vmax=0)
        axs[1].set_title("Enhanced Spectrogram")

        plt.tight_layout()

    return "enhanced.wav", fig

#with gr.Blocks() as demo:
#    gr.Markdown(ABOUT)
#    input_audio = gr.Audio(label="Input Audio", type="filepath", interactive=True)
#    enhance_btn = gr.Button("Enhance")
#    output_audio = gr.Audio(label="Enhanced Audio", type="filepath")
#    plot_output = gr.Plot(label="Spectrograms")
#
#    enhance_btn.click(fn=enhance, inputs=input_audio, outputs=[output_audio, plot_output])
#
#demo.queue().launch()

with gr.Blocks() as demo:
    gr.Markdown(ABOUT)
    input_audio = gr.Audio(label="Input Audio", type="filepath", interactive=True)
    model_choice = gr.Radio(
        label="Choose Model (The use of VCTK+DNS is recommended)",
        choices=["VCTK-Demand", "VCTK+DNS"],
        value="VCTK-Demand"
    )
    enhance_btn = gr.Button("Enhance")
    output_audio = gr.Audio(label="Enhanced Audio", type="filepath")
    plot_output = gr.Plot(label="Spectrograms")

    enhance_btn.click(
        fn=enhance,
        inputs=[input_audio, model_choice],
        outputs=[output_audio, plot_output]
    )
    gr.Markdown("**Note**: The current models are trained on 16kHz audio. Therefore, any input audio not sampled at 16kHz will be automatically resampled before enhancement.")

demo.queue().launch()
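For orientation: `enhance()` above processes audio in fixed 4-second windows, pads the final window with very low-level noise instead of zeros, and trims the padding after concatenation. A minimal sketch of that round-trip (not part of this commit; an identity pass stands in for the SEMamba model, so it runs on CPU with plain PyTorch):

import torch

x = torch.randn(3 * 16000 + 1234)   # ~3.08 s of "audio" at 16 kHz
segment_len = 4 * 16000             # the 4 s window used in enhance()

enhanced_chunks = []
for chunk in x.split(segment_len):
    if len(chunk) < segment_len:
        # low-level noise pad, as in enhance() (avoids an all-zero tail)
        pad = torch.randn(segment_len - len(chunk)) * 1e-4
        chunk = torch.cat([chunk, pad])
    enhanced_chunks.append(chunk)   # identity "model" for this sketch

y = torch.cat(enhanced_chunks)[:len(x)]   # trim the padding back off
assert torch.equal(y, x)   # lossless round-trip when the model is identity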
mamba_ssm/.DS_Store
ADDED
Binary file (6.15 kB)
mamba_ssm/__init__.py
ADDED
@@ -0,0 +1,5 @@
__version__ = "1.2.2"

from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
from mamba_ssm.modules.mamba_simple import Mamba
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
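The package root re-exports the public entry points, so downstream code imports them in one line. A quick smoke test (assuming the prebuilt CUDA wheel that app.py installs):

from mamba_ssm import Mamba, MambaLMHeadModel, selective_scan_fn

# All three names resolve to the submodules re-exported above.
print(Mamba.__module__, MambaLMHeadModel.__module__, selective_scan_fn.__module__)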
mamba_ssm/models/__init__.py
ADDED
File without changes
mamba_ssm/models/config_mamba.py
ADDED
@@ -0,0 +1,15 @@
from dataclasses import dataclass, field


@dataclass
class MambaConfig:

    d_model: int = 2560
    n_layer: int = 64
    vocab_size: int = 50277
    ssm_cfg: dict = field(default_factory=dict)
    rms_norm: bool = True
    residual_in_fp32: bool = True
    fused_add_norm: bool = True
    pad_vocab_size_multiple: int = 8
    tie_embeddings: bool = True
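The defaults above describe a large language model (d_model=2560 with n_layer=64 is the scale of the released 2.8B Mamba checkpoints); since MambaConfig is a plain dataclass, a small test configuration is one constructor call. A sketch with illustrative values:

from mamba_ssm.models.config_mamba import MambaConfig

cfg = MambaConfig(d_model=256, n_layer=4, vocab_size=1000)
print(cfg)

# Note: MambaLMHeadModel (next file) pads vocab_size up to a multiple of
# pad_vocab_size_multiple (8), so 1000 is kept while 1001 would become 1008.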
mamba_ssm/models/mixer_seq_simple.py
ADDED
@@ -0,0 +1,264 @@
# Copyright (c) 2023, Albert Gu, Tri Dao.

import math
from functools import partial
import json
import os

from collections import namedtuple

import torch
import torch.nn as nn

from mamba_ssm.models.config_mamba import MambaConfig
from mamba_ssm.modules.mamba_simple import Mamba, Block
from mamba_ssm.utils.generation import GenerationMixin
from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf

try:
    from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn
except ImportError:
    RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None


def create_block(
    d_model,
    ssm_cfg=None,
    norm_epsilon=1e-5,
    rms_norm=False,
    residual_in_fp32=False,
    fused_add_norm=False,
    layer_idx=None,
    device=None,
    dtype=None,
):
    if ssm_cfg is None:
        ssm_cfg = {}
    factory_kwargs = {"device": device, "dtype": dtype}
    mixer_cls = partial(Mamba, layer_idx=layer_idx, **ssm_cfg, **factory_kwargs)
    norm_cls = partial(
        nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs
    )
    block = Block(
        d_model,
        mixer_cls,
        norm_cls=norm_cls,
        fused_add_norm=fused_add_norm,
        residual_in_fp32=residual_in_fp32,
    )
    block.layer_idx = layer_idx
    return block


# https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
def _init_weights(
    module,
    n_layer,
    initializer_range=0.02,  # Now only used for embedding layer.
    rescale_prenorm_residual=True,
    n_residuals_per_layer=1,  # Change to 2 if we have MLP
):
    if isinstance(module, nn.Linear):
        if module.bias is not None:
            if not getattr(module.bias, "_no_reinit", False):
                nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
        nn.init.normal_(module.weight, std=initializer_range)

    if rescale_prenorm_residual:
        # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
        #   > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
        #   > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
        #   >   -- GPT-2 :: https://openai.com/blog/better-language-models/
        #
        # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
        for name, p in module.named_parameters():
            if name in ["out_proj.weight", "fc2.weight"]:
                # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
                # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
                # We need to reinit p since this code could be called multiple times
                # Having just p *= scale would repeatedly scale it down
                nn.init.kaiming_uniform_(p, a=math.sqrt(5))
                with torch.no_grad():
                    p /= math.sqrt(n_residuals_per_layer * n_layer)


class MixerModel(nn.Module):
    def __init__(
        self,
        d_model: int,
        n_layer: int,
        vocab_size: int,
        ssm_cfg=None,
        norm_epsilon: float = 1e-5,
        rms_norm: bool = False,
        initializer_cfg=None,
        fused_add_norm=False,
        residual_in_fp32=False,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.residual_in_fp32 = residual_in_fp32

        self.embedding = nn.Embedding(vocab_size, d_model, **factory_kwargs)

        # We change the order of residual and layer norm:
        # Instead of LN -> Attn / MLP -> Add, we do:
        # Add -> LN -> Attn / MLP / Mixer, returning both the residual branch (output of Add) and
        # the main branch (output of MLP / Mixer). The model definition is unchanged.
        # This is for performance reason: we can fuse add + layer_norm.
        self.fused_add_norm = fused_add_norm
        if self.fused_add_norm:
            if layer_norm_fn is None or rms_norm_fn is None:
                raise ImportError("Failed to import Triton LayerNorm / RMSNorm kernels")

        self.layers = nn.ModuleList(
            [
                create_block(
                    d_model,
                    ssm_cfg=ssm_cfg,
                    norm_epsilon=norm_epsilon,
                    rms_norm=rms_norm,
                    residual_in_fp32=residual_in_fp32,
                    fused_add_norm=fused_add_norm,
                    layer_idx=i,
                    **factory_kwargs,
                )
                for i in range(n_layer)
            ]
        )

        self.norm_f = (nn.LayerNorm if not rms_norm else RMSNorm)(
            d_model, eps=norm_epsilon, **factory_kwargs
        )

        self.apply(
            partial(
                _init_weights,
                n_layer=n_layer,
                **(initializer_cfg if initializer_cfg is not None else {}),
            )
        )

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        return {
            i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
            for i, layer in enumerate(self.layers)
        }

    def forward(self, input_ids, inference_params=None):
        hidden_states = self.embedding(input_ids)
        residual = None
        for layer in self.layers:
            hidden_states, residual = layer(
                hidden_states, residual, inference_params=inference_params
            )
        if not self.fused_add_norm:
            residual = (hidden_states + residual) if residual is not None else hidden_states
            hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
        else:
            # Set prenorm=False here since we don't need the residual
            fused_add_norm_fn = rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn
            hidden_states = fused_add_norm_fn(
                hidden_states,
                self.norm_f.weight,
                self.norm_f.bias,
                eps=self.norm_f.eps,
                residual=residual,
                prenorm=False,
                residual_in_fp32=self.residual_in_fp32,
            )
        return hidden_states


class MambaLMHeadModel(nn.Module, GenerationMixin):

    def __init__(
        self,
        config: MambaConfig,
        initializer_cfg=None,
        device=None,
        dtype=None,
    ) -> None:
        self.config = config
        d_model = config.d_model
        n_layer = config.n_layer
        vocab_size = config.vocab_size
        ssm_cfg = config.ssm_cfg
        rms_norm = config.rms_norm
        residual_in_fp32 = config.residual_in_fp32
        fused_add_norm = config.fused_add_norm
        pad_vocab_size_multiple = config.pad_vocab_size_multiple
        factory_kwargs = {"device": device, "dtype": dtype}

        super().__init__()
        if vocab_size % pad_vocab_size_multiple != 0:
            vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple)
        self.backbone = MixerModel(
            d_model=d_model,
            n_layer=n_layer,
            vocab_size=vocab_size,
            ssm_cfg=ssm_cfg,
            rms_norm=rms_norm,
            initializer_cfg=initializer_cfg,
            fused_add_norm=fused_add_norm,
            residual_in_fp32=residual_in_fp32,
            **factory_kwargs,
        )
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs)

        # Initialize weights and apply final processing
        self.apply(
            partial(
                _init_weights,
                n_layer=n_layer,
                **(initializer_cfg if initializer_cfg is not None else {}),
            )
        )
        self.tie_weights()

    def tie_weights(self):
        if self.config.tie_embeddings:
            self.lm_head.weight = self.backbone.embedding.weight

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        return self.backbone.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)

    def forward(self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0):
        """
        "position_ids" is just to be compatible with Transformer generation. We don't use it.
        num_last_tokens: if > 0, only return the logits for the last n tokens
        """
        hidden_states = self.backbone(input_ids, inference_params=inference_params)
        if num_last_tokens > 0:
            hidden_states = hidden_states[:, -num_last_tokens:]
        lm_logits = self.lm_head(hidden_states)
        CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
        return CausalLMOutput(logits=lm_logits)

    @classmethod
    def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs):
        config_data = load_config_hf(pretrained_model_name)
        config = MambaConfig(**config_data)
        model = cls(config, device=device, dtype=dtype, **kwargs)
        model.load_state_dict(load_state_dict_hf(pretrained_model_name, device=device, dtype=dtype))
        return model

    def save_pretrained(self, save_directory):
        """
        Minimal implementation of save_pretrained for MambaLMHeadModel.
        Save the model and its configuration file to a directory.
        """
        # Ensure save_directory exists
        os.makedirs(save_directory, exist_ok=True)

        # Save the model's state_dict
        model_path = os.path.join(save_directory, 'pytorch_model.bin')
        torch.save(self.state_dict(), model_path)

        # Save the configuration of the model
        config_path = os.path.join(save_directory, 'config.json')
        with open(config_path, 'w') as f:
            json.dump(self.config.__dict__, f)
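A hedged sketch of one forward pass through MambaLMHeadModel, assuming the CUDA wheel installed by app.py and an available GPU (the default config enables the fused Triton RMSNorm path); sizes are illustrative:

import torch
from mamba_ssm.models.config_mamba import MambaConfig
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

cfg = MambaConfig(d_model=256, n_layer=2, vocab_size=1024)
lm = MambaLMHeadModel(cfg, device="cuda", dtype=torch.float16)

input_ids = torch.randint(0, 1024, (2, 64), device="cuda")  # (batch, seqlen)
out = lm(input_ids)            # returns CausalLMOutput(logits=...)
print(out.logits.shape)        # torch.Size([2, 64, 1024]); 1024 is already a multiple of 8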
mamba_ssm/modules/__init__.py
ADDED
File without changes
mamba_ssm/modules/mamba_simple.py
ADDED
@@ -0,0 +1,353 @@
# Copyright (c) 2023, Tri Dao, Albert Gu.

import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

from einops import rearrange, repeat

from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn

try:
    from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
except ImportError:
    causal_conv1d_fn, causal_conv1d_update = None, None

try:
    from mamba_ssm.ops.triton.selective_state_update import selective_state_update
except ImportError:
    selective_state_update = None

try:
    from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn
except ImportError:
    RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None


class Mamba(nn.Module):
    def __init__(
        self,
        d_model,
        d_state=16,
        d_conv=4,
        expand=2,
        dt_rank="auto",
        dt_min=0.001,
        dt_max=0.1,
        dt_init="random",
        dt_scale=1.0,
        dt_init_floor=1e-4,
        conv_bias=True,
        bias=False,
        use_fast_path=True,  # Fused kernel options
        layer_idx=None,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.expand = expand
        self.d_inner = int(self.expand * self.d_model)
        self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
        self.use_fast_path = use_fast_path
        self.layer_idx = layer_idx

        self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs)

        self.conv1d = nn.Conv1d(
            in_channels=self.d_inner,
            out_channels=self.d_inner,
            bias=conv_bias,
            kernel_size=d_conv,
            groups=self.d_inner,
            padding=d_conv - 1,
            **factory_kwargs,
        )

        self.activation = "silu"
        self.act = nn.SiLU()

        self.x_proj = nn.Linear(
            self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs
        )
        self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs)

        # Initialize special dt projection to preserve variance at initialization
        dt_init_std = self.dt_rank**-0.5 * dt_scale
        if dt_init == "constant":
            nn.init.constant_(self.dt_proj.weight, dt_init_std)
        elif dt_init == "random":
            nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
        else:
            raise NotImplementedError

        # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
        dt = torch.exp(
            torch.rand(self.d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
            + math.log(dt_min)
        ).clamp(min=dt_init_floor)
        # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
        inv_dt = dt + torch.log(-torch.expm1(-dt))
        with torch.no_grad():
            self.dt_proj.bias.copy_(inv_dt)
        # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
        self.dt_proj.bias._no_reinit = True

        # S4D real initialization
        A = repeat(
            torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),
            "n -> d n",
            d=self.d_inner,
        ).contiguous()
        A_log = torch.log(A)  # Keep A_log in fp32
        self.A_log = nn.Parameter(A_log)
        self.A_log._no_weight_decay = True

        # D "skip" parameter
        self.D = nn.Parameter(torch.ones(self.d_inner, device=device))  # Keep in fp32
        self.D._no_weight_decay = True

        self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)

    def forward(self, hidden_states, inference_params=None):
        """
        hidden_states: (B, L, D)
        Returns: same shape as hidden_states
        """
        batch, seqlen, dim = hidden_states.shape

        conv_state, ssm_state = None, None
        if inference_params is not None:
            conv_state, ssm_state = self._get_states_from_cache(inference_params, batch)
            if inference_params.seqlen_offset > 0:
                # The states are updated inplace
                out, _, _ = self.step(hidden_states, conv_state, ssm_state)
                return out

        # We do matmul and transpose BLH -> HBL at the same time
        xz = rearrange(
            self.in_proj.weight @ rearrange(hidden_states, "b l d -> d (b l)"),
            "d (b l) -> b d l",
            l=seqlen,
        )
        if self.in_proj.bias is not None:
            xz = xz + rearrange(self.in_proj.bias.to(dtype=xz.dtype), "d -> d 1")

        A = -torch.exp(self.A_log.float())  # (d_inner, d_state)
        # In the backward pass we write dx and dz next to each other to avoid torch.cat
        if self.use_fast_path and causal_conv1d_fn is not None and inference_params is None:  # Doesn't support outputting the states
            out = mamba_inner_fn(
                xz,
                self.conv1d.weight,
                self.conv1d.bias,
                self.x_proj.weight,
                self.dt_proj.weight,
                self.out_proj.weight,
                self.out_proj.bias,
                A,
                None,  # input-dependent B
                None,  # input-dependent C
                self.D.float(),
                delta_bias=self.dt_proj.bias.float(),
                delta_softplus=True,
            )
        else:
            x, z = xz.chunk(2, dim=1)
            # Compute short convolution
            if conv_state is not None:
                # If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
                # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
                conv_state.copy_(F.pad(x, (self.d_conv - x.shape[-1], 0)))  # Update state (B D W)
            if causal_conv1d_fn is None:
                x = self.act(self.conv1d(x)[..., :seqlen])
            else:
                assert self.activation in ["silu", "swish"]
                x = causal_conv1d_fn(
                    x=x,
                    weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
                    bias=self.conv1d.bias,
                    activation=self.activation,
                )

            # We're careful here about the layout, to avoid extra transposes.
            # We want dt to have d as the slowest moving dimension
            # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
            x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d"))  # (bl d)
            dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
            dt = self.dt_proj.weight @ dt.t()
            dt = rearrange(dt, "d (b l) -> b d l", l=seqlen)
            B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
            C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
            assert self.activation in ["silu", "swish"]
            y = selective_scan_fn(
                x,
                dt,
                A,
                B,
                C,
                self.D.float(),
                z=z,
                delta_bias=self.dt_proj.bias.float(),
                delta_softplus=True,
                return_last_state=ssm_state is not None,
            )
            if ssm_state is not None:
                y, last_state = y
                ssm_state.copy_(last_state)
            y = rearrange(y, "b d l -> b l d")
            out = self.out_proj(y)
        return out

    def step(self, hidden_states, conv_state, ssm_state):
        dtype = hidden_states.dtype
        assert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now"
        xz = self.in_proj(hidden_states.squeeze(1))  # (B 2D)
        x, z = xz.chunk(2, dim=-1)  # (B D)

        # Conv step
        if causal_conv1d_update is None:
            conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1))  # Update state (B D W)
            conv_state[:, :, -1] = x
            x = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1)  # (B D)
            if self.conv1d.bias is not None:
                x = x + self.conv1d.bias
            x = self.act(x).to(dtype=dtype)
        else:
            x = causal_conv1d_update(
                x,
                conv_state,
                rearrange(self.conv1d.weight, "d 1 w -> d w"),
                self.conv1d.bias,
                self.activation,
            )

        x_db = self.x_proj(x)  # (B dt_rank+2*d_state)
        dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
        # Don't add dt_bias here
        dt = F.linear(dt, self.dt_proj.weight)  # (B d_inner)
        A = -torch.exp(self.A_log.float())  # (d_inner, d_state)

        # SSM step
        if selective_state_update is None:
            # Discretize A and B
            dt = F.softplus(dt + self.dt_proj.bias.to(dtype=dt.dtype))
            dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A))
            dB = torch.einsum("bd,bn->bdn", dt, B)
            ssm_state.copy_(ssm_state * dA + rearrange(x, "b d -> b d 1") * dB)
            y = torch.einsum("bdn,bn->bd", ssm_state.to(dtype), C)
            y = y + self.D.to(dtype) * x
            y = y * self.act(z)  # (B D)
        else:
            y = selective_state_update(
                ssm_state, x, dt, A, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True
            )

        out = self.out_proj(y)
        return out.unsqueeze(1), conv_state, ssm_state

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        device = self.out_proj.weight.device
        conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype
        conv_state = torch.zeros(
            batch_size, self.d_model * self.expand, self.d_conv, device=device, dtype=conv_dtype
        )
        ssm_dtype = self.dt_proj.weight.dtype if dtype is None else dtype
        # ssm_dtype = torch.float32
        ssm_state = torch.zeros(
            batch_size, self.d_model * self.expand, self.d_state, device=device, dtype=ssm_dtype
        )
        return conv_state, ssm_state

    def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False):
        assert self.layer_idx is not None
        if self.layer_idx not in inference_params.key_value_memory_dict:
            batch_shape = (batch_size,)
            conv_state = torch.zeros(
                batch_size,
                self.d_model * self.expand,
                self.d_conv,
                device=self.conv1d.weight.device,
                dtype=self.conv1d.weight.dtype,
            )
            ssm_state = torch.zeros(
                batch_size,
                self.d_model * self.expand,
                self.d_state,
                device=self.dt_proj.weight.device,
                dtype=self.dt_proj.weight.dtype,
                # dtype=torch.float32,
            )
            inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state)
        else:
            conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx]
            # TODO: What if batch size changes between generation, and we reuse the same states?
            if initialize_states:
                conv_state.zero_()
                ssm_state.zero_()
        return conv_state, ssm_state


class Block(nn.Module):
    def __init__(
        self, dim, mixer_cls, norm_cls=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False
    ):
        """
        Simple block wrapping a mixer class with LayerNorm/RMSNorm and residual connection.

        This Block has a slightly different structure compared to a regular
        prenorm Transformer block.
        The standard block is: LN -> MHA/MLP -> Add.
        [Ref: https://arxiv.org/abs/2002.04745]
        Here we have: Add -> LN -> Mixer, returning both
        the hidden_states (output of the mixer) and the residual.
        This is purely for performance reasons, as we can fuse add and LayerNorm.
        The residual needs to be provided (except for the very first block).
        """
        super().__init__()
        self.residual_in_fp32 = residual_in_fp32
        self.fused_add_norm = fused_add_norm
        self.mixer = mixer_cls(dim)
        self.norm = norm_cls(dim)
        if self.fused_add_norm:
            assert RMSNorm is not None, "RMSNorm import fails"
            assert isinstance(
                self.norm, (nn.LayerNorm, RMSNorm)
            ), "Only LayerNorm and RMSNorm are supported for fused_add_norm"

    def forward(
        self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None
    ):
        r"""Pass the input through the encoder layer.

        Args:
            hidden_states: the sequence to the encoder layer (required).
            residual: hidden_states = Mixer(LN(residual))
        """
        if not self.fused_add_norm:
            residual = (hidden_states + residual) if residual is not None else hidden_states
            hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
            if self.residual_in_fp32:
                residual = residual.to(torch.float32)
        else:
            fused_add_norm_fn = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn
            hidden_states, residual = fused_add_norm_fn(
                hidden_states,
                self.norm.weight,
                self.norm.bias,
                residual=residual,
                prenorm=True,
                residual_in_fp32=self.residual_in_fp32,
                eps=self.norm.eps,
            )
        hidden_states = self.mixer(hidden_states, inference_params=inference_params)
        return hidden_states, residual

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
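As the forward docstring notes, the Mamba mixer is shape-preserving over (B, L, D). A small sketch (assumes the CUDA kernels are installed and a GPU is present; numbers are illustrative):

import torch
from mamba_ssm.modules.mamba_simple import Mamba

block = Mamba(d_model=128, d_state=16, d_conv=4, expand=2,
              device="cuda", dtype=torch.float16)
x = torch.randn(2, 64, 128, device="cuda", dtype=torch.float16)
y = block(x)
print(y.shape)  # torch.Size([2, 64, 128]), same (B, L, D) as the input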
mamba_ssm/ops/__init__.py
ADDED
File without changes
mamba_ssm/ops/selective_scan_interface.py
ADDED
@@ -0,0 +1,357 @@
# Copyright (c) 2023, Tri Dao, Albert Gu.

import torch
import torch.nn.functional as F
from torch.cuda.amp import custom_bwd, custom_fwd

from einops import rearrange, repeat

try:
    from causal_conv1d import causal_conv1d_fn
    import causal_conv1d_cuda
except ImportError:
    causal_conv1d_fn = None
    causal_conv1d_cuda = None

import selective_scan_cuda


class SelectiveScanFn(torch.autograd.Function):

    @staticmethod
    def forward(ctx, u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
                return_last_state=False):
        if u.stride(-1) != 1:
            u = u.contiguous()
        if delta.stride(-1) != 1:
            delta = delta.contiguous()
        if D is not None:
            D = D.contiguous()
        if B.stride(-1) != 1:
            B = B.contiguous()
        if C.stride(-1) != 1:
            C = C.contiguous()
        if z is not None and z.stride(-1) != 1:
            z = z.contiguous()
        if B.dim() == 3:
            B = rearrange(B, "b dstate l -> b 1 dstate l")
            ctx.squeeze_B = True
        if C.dim() == 3:
            C = rearrange(C, "b dstate l -> b 1 dstate l")
            ctx.squeeze_C = True
        out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus)
        ctx.delta_softplus = delta_softplus
        ctx.has_z = z is not None
        last_state = x[:, :, -1, 1::2]  # (batch, dim, dstate)
        if not ctx.has_z:
            ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x)
            return out if not return_last_state else (out, last_state)
        else:
            ctx.save_for_backward(u, delta, A, B, C, D, z, delta_bias, x, out)
            out_z = rest[0]
            return out_z if not return_last_state else (out_z, last_state)

    @staticmethod
    def backward(ctx, dout, *args):
        if not ctx.has_z:
            u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors
            z = None
            out = None
        else:
            u, delta, A, B, C, D, z, delta_bias, x, out = ctx.saved_tensors
        if dout.stride(-1) != 1:
            dout = dout.contiguous()
        # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
        # backward of selective_scan_cuda with the backward of chunk).
        # Here we just pass in None and dz will be allocated in the C++ code.
        du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd(
            u, delta, A, B, C, D, z, delta_bias, dout, x, out, None, ctx.delta_softplus,
            False  # option to recompute out_z, not used here
        )
        dz = rest[0] if ctx.has_z else None
        dB = dB.squeeze(1) if getattr(ctx, "squeeze_B", False) else dB
        dC = dC.squeeze(1) if getattr(ctx, "squeeze_C", False) else dC
        return (du, ddelta, dA, dB, dC,
                dD if D is not None else None,
                dz,
                ddelta_bias if delta_bias is not None else None,
                None,
                None)


def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
                      return_last_state=False):
    """if return_last_state is True, returns (out, last_state)
    last_state has shape (batch, dim, dstate). Note that the gradient of the last state is
    not considered in the backward pass.
    """
    return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state)


def selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
                       return_last_state=False):
    """
    u: r(B D L)
    delta: r(B D L)
    A: c(D N) or r(D N)
    B: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
    C: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
    D: r(D)
    z: r(B D L)
    delta_bias: r(D), fp32

    out: r(B D L)
    last_state (optional): r(B D dstate) or c(B D dstate)
    """
    dtype_in = u.dtype
    u = u.float()
    delta = delta.float()
    if delta_bias is not None:
        delta = delta + delta_bias[..., None].float()
    if delta_softplus:
        delta = F.softplus(delta)
    batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1]
    is_variable_B = B.dim() >= 3
    is_variable_C = C.dim() >= 3
    if A.is_complex():
        if is_variable_B:
            B = torch.view_as_complex(rearrange(B.float(), "... (L two) -> ... L two", two=2))
        if is_variable_C:
            C = torch.view_as_complex(rearrange(C.float(), "... (L two) -> ... L two", two=2))
    else:
        B = B.float()
        C = C.float()
    x = A.new_zeros((batch, dim, dstate))
    ys = []
    deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A))
    if not is_variable_B:
        deltaB_u = torch.einsum('bdl,dn,bdl->bdln', delta, B, u)
    else:
        if B.dim() == 3:
            deltaB_u = torch.einsum('bdl,bnl,bdl->bdln', delta, B, u)
        else:
            B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1])
            deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u)
    if is_variable_C and C.dim() == 4:
        C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1])
    last_state = None
    for i in range(u.shape[2]):
        x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
        if not is_variable_C:
            y = torch.einsum('bdn,dn->bd', x, C)
        else:
            if C.dim() == 3:
                y = torch.einsum('bdn,bn->bd', x, C[:, :, i])
            else:
                y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i])
        if i == u.shape[2] - 1:
            last_state = x
        if y.is_complex():
            y = y.real * 2
        ys.append(y)
    y = torch.stack(ys, dim=2)  # (batch dim L)
    out = y if D is None else y + u * rearrange(D, "d -> d 1")
    if z is not None:
        out = out * F.silu(z)
    out = out.to(dtype=dtype_in)
    return out if not return_last_state else (out, last_state)


class MambaInnerFn(torch.autograd.Function):

    @staticmethod
    @custom_fwd
    def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
                out_proj_weight, out_proj_bias,
                A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
                C_proj_bias=None, delta_softplus=True, checkpoint_lvl=1):
        """
        xz: (batch, dim, seqlen)
        """
        assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d."
        assert checkpoint_lvl in [0, 1]
        L = xz.shape[-1]
        delta_rank = delta_proj_weight.shape[1]
        d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
        if torch.is_autocast_enabled():
            x_proj_weight = x_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
            delta_proj_weight = delta_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
            out_proj_weight = out_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
            out_proj_bias = (out_proj_bias.to(dtype=torch.get_autocast_gpu_dtype())
                             if out_proj_bias is not None else None)
        if xz.stride(-1) != 1:
            xz = xz.contiguous()
        conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w")
        x, z = xz.chunk(2, dim=1)
        conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None
        conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(
            x, conv1d_weight, conv1d_bias, None, None, None, True
        )
        # We're being very careful here about the layout, to avoid extra transposes.
        # We want delta to have d as the slowest moving dimension
        # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
        x_dbl = F.linear(rearrange(conv1d_out, 'b d l -> (b l) d'), x_proj_weight)  # (bl d)
        delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l = L)
        ctx.is_variable_B = B is None
        ctx.is_variable_C = C is None
        ctx.B_proj_bias_is_None = B_proj_bias is None
        ctx.C_proj_bias_is_None = C_proj_bias is None
        if B is None:  # variable B
            B = x_dbl[:, delta_rank:delta_rank + d_state]  # (bl dstate)
            if B_proj_bias is not None:
                B = B + B_proj_bias.to(dtype=B.dtype)
            if not A.is_complex():
                # B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
                B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
            else:
                B = rearrange(B, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
        else:
            if B.stride(-1) != 1:
                B = B.contiguous()
        if C is None:  # variable C
            C = x_dbl[:, -d_state:]  # (bl dstate)
            if C_proj_bias is not None:
                C = C + C_proj_bias.to(dtype=C.dtype)
            if not A.is_complex():
                # C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
                C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
            else:
                C = rearrange(C, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
        else:
            if C.stride(-1) != 1:
                C = C.contiguous()
        if D is not None:
            D = D.contiguous()
        out, scan_intermediates, out_z = selective_scan_cuda.fwd(
            conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus
        )
        ctx.delta_softplus = delta_softplus
        ctx.out_proj_bias_is_None = out_proj_bias is None
        ctx.checkpoint_lvl = checkpoint_lvl
        if checkpoint_lvl >= 1:  # Will recompute conv1d_out and delta in the backward pass
            conv1d_out, delta = None, None
        ctx.save_for_backward(xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight,
                              delta_proj_weight, out_proj_weight, conv1d_out, delta,
                              A, B, C, D, delta_bias, scan_intermediates, out)
        return F.linear(rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias)

    @staticmethod
    @custom_bwd
    def backward(ctx, dout):
        # dout: (batch, seqlen, dim)
        assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d."
        (xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, delta_proj_weight, out_proj_weight,
         conv1d_out, delta, A, B, C, D, delta_bias, scan_intermediates, out) = ctx.saved_tensors
        L = xz.shape[-1]
        delta_rank = delta_proj_weight.shape[1]
        d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
        x, z = xz.chunk(2, dim=1)
        if dout.stride(-1) != 1:
            dout = dout.contiguous()
        if ctx.checkpoint_lvl == 1:
            conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(
                x, conv1d_weight, conv1d_bias, None, None, None, True
            )
            delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(),
                              "d (b l) -> b d l", l = L)
        # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
        # backward of selective_scan_cuda with the backward of chunk).
        dxz = torch.empty_like(xz)  # (batch, dim, seqlen)
        dx, dz = dxz.chunk(2, dim=1)
        dout = rearrange(dout, "b l e -> e (b l)")
        dout_y = rearrange(out_proj_weight.t() @ dout, "d (b l) -> b d l", l=L)
        dconv1d_out, ddelta, dA, dB, dC, dD, ddelta_bias, dz, out_z = selective_scan_cuda.bwd(
            conv1d_out, delta, A, B, C, D, z, delta_bias, dout_y, scan_intermediates, out, dz,
            ctx.delta_softplus,
            True  # option to recompute out_z
        )
        dout_proj_weight = torch.einsum("eB,dB->ed", dout, rearrange(out_z, "b d l -> d (b l)"))
        dout_proj_bias = dout.sum(dim=(0, 1)) if not ctx.out_proj_bias_is_None else None
        dD = dD if D is not None else None
        dx_dbl = torch.empty_like(x_dbl)
        dB_proj_bias = None
        if ctx.is_variable_B:
            if not A.is_complex():
                dB = rearrange(dB, "b 1 dstate l -> (b l) dstate").contiguous()
            else:
                dB = rearrange(dB, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
            dB_proj_bias = dB.sum(0) if not ctx.B_proj_bias_is_None else None
            dx_dbl[:, delta_rank:delta_rank + d_state] = dB  # (bl d)
            dB = None
dB = None
|
281 |
+
dC_proj_bias = None
|
282 |
+
if ctx.is_variable_C:
|
283 |
+
if not A.is_complex():
|
284 |
+
dC = rearrange(dC, "b 1 dstate l -> (b l) dstate").contiguous()
|
285 |
+
else:
|
286 |
+
dC = rearrange(dC, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
|
287 |
+
dC_proj_bias = dC.sum(0) if not ctx.C_proj_bias_is_None else None
|
288 |
+
dx_dbl[:, -d_state:] = dC # (bl d)
|
289 |
+
dC = None
|
290 |
+
ddelta = rearrange(ddelta, "b d l -> d (b l)")
|
291 |
+
ddelta_proj_weight = torch.einsum("dB,Br->dr", ddelta, x_dbl[:, :delta_rank])
|
292 |
+
dx_dbl[:, :delta_rank] = torch.einsum("dB,dr->Br", ddelta, delta_proj_weight)
|
293 |
+
dconv1d_out = rearrange(dconv1d_out, "b d l -> d (b l)")
|
294 |
+
dx_proj_weight = torch.einsum("Br,Bd->rd", dx_dbl, rearrange(conv1d_out, "b d l -> (b l) d"))
|
295 |
+
dconv1d_out = torch.addmm(dconv1d_out, x_proj_weight.t(), dx_dbl.t(), out=dconv1d_out)
|
296 |
+
dconv1d_out = rearrange(dconv1d_out, "d (b l) -> b d l", b=x.shape[0], l=x.shape[-1])
|
297 |
+
# The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
|
298 |
+
# backward of conv1d with the backward of chunk).
|
299 |
+
dx, dconv1d_weight, dconv1d_bias, *_ = causal_conv1d_cuda.causal_conv1d_bwd(
|
300 |
+
x, conv1d_weight, conv1d_bias, dconv1d_out, None, None, None, dx, False, True
|
301 |
+
)
|
302 |
+
dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None
|
303 |
+
dconv1d_weight = rearrange(dconv1d_weight, "d w -> d 1 w")
|
304 |
+
return (dxz, dconv1d_weight, dconv1d_bias, dx_proj_weight, ddelta_proj_weight,
|
305 |
+
dout_proj_weight, dout_proj_bias,
|
306 |
+
dA, dB, dC, dD,
|
307 |
+
ddelta_bias if delta_bias is not None else None,
|
308 |
+
dB_proj_bias, dC_proj_bias, None)
|
309 |
+
|
310 |
+
|
311 |
+
def mamba_inner_fn(
|
312 |
+
xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
|
313 |
+
out_proj_weight, out_proj_bias,
|
314 |
+
A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
|
315 |
+
C_proj_bias=None, delta_softplus=True
|
316 |
+
):
|
317 |
+
return MambaInnerFn.apply(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
|
318 |
+
out_proj_weight, out_proj_bias,
|
319 |
+
A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus)
|
320 |
+
|
321 |
+
|
322 |
+
def mamba_inner_ref(
|
323 |
+
xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
|
324 |
+
out_proj_weight, out_proj_bias,
|
325 |
+
A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
|
326 |
+
C_proj_bias=None, delta_softplus=True
|
327 |
+
):
|
328 |
+
assert causal_conv1d_fn is not None, "causal_conv1d_fn is not available. Please install causal-conv1d."
|
329 |
+
L = xz.shape[-1]
|
330 |
+
delta_rank = delta_proj_weight.shape[1]
|
331 |
+
d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
|
332 |
+
x, z = xz.chunk(2, dim=1)
|
333 |
+
x = causal_conv1d_fn(x, rearrange(conv1d_weight, "d 1 w -> d w"), conv1d_bias, activation="silu")
|
334 |
+
# We're being very careful here about the layout, to avoid extra transposes.
|
335 |
+
# We want delta to have d as the slowest moving dimension
|
336 |
+
# and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
|
337 |
+
x_dbl = F.linear(rearrange(x, 'b d l -> (b l) d'), x_proj_weight) # (bl d)
|
338 |
+
delta = delta_proj_weight @ x_dbl[:, :delta_rank].t()
|
339 |
+
delta = rearrange(delta, "d (b l) -> b d l", l=L)
|
340 |
+
if B is None: # variable B
|
341 |
+
B = x_dbl[:, delta_rank:delta_rank + d_state] # (bl d)
|
342 |
+
if B_proj_bias is not None:
|
343 |
+
B = B + B_proj_bias.to(dtype=B.dtype)
|
344 |
+
if not A.is_complex():
|
345 |
+
B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
|
346 |
+
else:
|
347 |
+
B = rearrange(B, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous()
|
348 |
+
if C is None: # variable B
|
349 |
+
C = x_dbl[:, -d_state:] # (bl d)
|
350 |
+
if C_proj_bias is not None:
|
351 |
+
C = C + C_proj_bias.to(dtype=C.dtype)
|
352 |
+
if not A.is_complex():
|
353 |
+
C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
|
354 |
+
else:
|
355 |
+
C = rearrange(C, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous()
|
356 |
+
y = selective_scan_fn(x, delta, A, B, C, D, z=z, delta_bias=delta_bias, delta_softplus=True)
|
357 |
+
return F.linear(rearrange(y, "b d l -> b l d"), out_proj_weight, out_proj_bias)
|
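The reference path above runs in plain PyTorch, so it can be exercised without the CUDA extensions. A minimal smoke-test sketch, assuming illustrative sizes and random stand-in tensors (the shapes follow the docstrings in this file; nothing here is shipped by the Space):

import torch
from mamba_ssm.ops.selective_scan_interface import selective_scan_ref

batch, dim, dstate, L = 2, 4, 8, 16
u = torch.randn(batch, dim, L)
delta = torch.randn(batch, dim, L)
A = -torch.rand(dim, dstate)           # negative real A keeps the scan stable
B = torch.randn(batch, dstate, L)      # "variable B" layout: (batch, dstate, seqlen)
C = torch.randn(batch, dstate, L)      # "variable C" layout, consumed step by step above
D = torch.randn(dim)
z = torch.randn(batch, dim, L)
out, last_state = selective_scan_ref(u, delta, A, B, C, D, z=z,
                                     delta_softplus=True, return_last_state=True)
print(out.shape, last_state.shape)     # (2, 4, 16) and (2, 4, 8)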
mamba_ssm/ops/triton/__init__.py
ADDED
File without changes
mamba_ssm/ops/triton/layernorm.py
ADDED
@@ -0,0 +1,635 @@
# Copyright (c) 2023, Tri Dao.
# Implement residual + layer_norm / rms_norm.

# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate.
# This is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.

import math

import torch
import torch.nn.functional as F
from torch.cuda.amp import custom_fwd, custom_bwd

import triton
import triton.language as tl


def layer_norm_ref(x, weight, bias, residual=None, eps=1e-6, prenorm=False, upcast=False):
    dtype = x.dtype
    if upcast:
        weight = weight.float()
        bias = bias.float() if bias is not None else None
    if upcast:
        x = x.float()
        residual = residual.float() if residual is not None else residual
    if residual is not None:
        x = (x + residual).to(x.dtype)
    out = F.layer_norm(x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps).to(
        dtype
    )
    return out if not prenorm else (out, x)


def rms_norm_ref(x, weight, bias, residual=None, eps=1e-6, prenorm=False, upcast=False):
    dtype = x.dtype
    if upcast:
        weight = weight.float()
        bias = bias.float() if bias is not None else None
    if upcast:
        x = x.float()
        residual = residual.float() if residual is not None else residual
    if residual is not None:
        x = (x + residual).to(x.dtype)
    rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
    out = (x * rstd * weight) + bias if bias is not None else (x * rstd * weight)
    out = out.to(dtype)
    return out if not prenorm else (out, x)


@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8),
        triton.Config({}, num_warps=16),
        triton.Config({}, num_warps=32),
    ],
    key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"],
)
# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
# @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None})
@triton.jit
def _layer_norm_fwd_1pass_kernel(
    X,  # pointer to the input
    Y,  # pointer to the output
    W,  # pointer to the weights
    B,  # pointer to the biases
    RESIDUAL,  # pointer to the residual
    RESIDUAL_OUT,  # pointer to the residual
    Mean,  # pointer to the mean
    Rstd,  # pointer to the 1/std
    stride_x_row,  # how much to increase the pointer when moving by 1 row
    stride_y_row,
    stride_res_row,
    stride_res_out_row,
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    IS_RMS_NORM: tl.constexpr,
    BLOCK_N: tl.constexpr,
    HAS_RESIDUAL: tl.constexpr,
    STORE_RESIDUAL_OUT: tl.constexpr,
    HAS_BIAS: tl.constexpr,
):
    # Map the program id to the row of X and Y it should compute.
    row = tl.program_id(0)
    X += row * stride_x_row
    Y += row * stride_y_row
    if HAS_RESIDUAL:
        RESIDUAL += row * stride_res_row
    if STORE_RESIDUAL_OUT:
        RESIDUAL_OUT += row * stride_res_out_row
    # Compute mean and variance
    cols = tl.arange(0, BLOCK_N)
    x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
    if HAS_RESIDUAL:
        residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)
        x += residual
    if STORE_RESIDUAL_OUT:
        tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)
    if not IS_RMS_NORM:
        mean = tl.sum(x, axis=0) / N
        tl.store(Mean + row, mean)
        xbar = tl.where(cols < N, x - mean, 0.0)
        var = tl.sum(xbar * xbar, axis=0) / N
    else:
        xbar = tl.where(cols < N, x, 0.0)
        var = tl.sum(xbar * xbar, axis=0) / N
    rstd = 1 / tl.sqrt(var + eps)
    tl.store(Rstd + row, rstd)
    # Normalize and apply linear transformation
    mask = cols < N
    w = tl.load(W + cols, mask=mask).to(tl.float32)
    if HAS_BIAS:
        b = tl.load(B + cols, mask=mask).to(tl.float32)
    x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
    y = x_hat * w + b if HAS_BIAS else x_hat * w
    # Write output
    tl.store(Y + cols, y, mask=mask)


def _layer_norm_fwd(
    x, weight, bias, eps, residual=None, out_dtype=None, residual_dtype=None, is_rms_norm=False
):
    if residual is not None:
        residual_dtype = residual.dtype
    M, N = x.shape
    assert x.stride(-1) == 1
    if residual is not None:
        assert residual.stride(-1) == 1
        assert residual.shape == (M, N)
    assert weight.shape == (N,)
    assert weight.stride(-1) == 1
    if bias is not None:
        assert bias.stride(-1) == 1
        assert bias.shape == (N,)
    # allocate output
    y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
    assert y.stride(-1) == 1
    if residual is not None or (residual_dtype is not None and residual_dtype != x.dtype):
        residual_out = torch.empty(M, N, device=x.device, dtype=residual_dtype)
        assert residual_out.stride(-1) == 1
    else:
        residual_out = None
    mean = torch.empty((M,), dtype=torch.float32, device=x.device) if not is_rms_norm else None
    rstd = torch.empty((M,), dtype=torch.float32, device=x.device)
    # Less than 64KB per feature: enqueue fused kernel
    MAX_FUSED_SIZE = 65536 // x.element_size()
    BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
    if N > BLOCK_N:
        raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
    # heuristics for number of warps
    with torch.cuda.device(x.device.index):
        _layer_norm_fwd_1pass_kernel[(M,)](
            x,
            y,
            weight,
            bias,
            residual,
            residual_out,
            mean,
            rstd,
            x.stride(0),
            y.stride(0),
            residual.stride(0) if residual is not None else 0,
            residual_out.stride(0) if residual_out is not None else 0,
            N,
            eps,
            is_rms_norm,
            BLOCK_N,
            residual is not None,
            residual_out is not None,
            bias is not None,
        )
    # residual_out is None if residual is None and residual_dtype == input_dtype
    return y, mean, rstd, residual_out if residual_out is not None else x


@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8),
        triton.Config({}, num_warps=16),
        triton.Config({}, num_warps=32),
    ],
    key=["N", "HAS_DRESIDUAL", "STORE_DRESIDUAL", "IS_RMS_NORM", "HAS_BIAS"],
)
# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
# @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None})
# @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None})
@triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None})
@triton.jit
def _layer_norm_bwd_kernel(
    X,  # pointer to the input
    W,  # pointer to the weights
    B,  # pointer to the biases
    Y,  # pointer to the output to be recomputed
    DY,  # pointer to the output gradient
    DX,  # pointer to the input gradient
    DW,  # pointer to the partial sum of weights gradient
    DB,  # pointer to the partial sum of biases gradient
    DRESIDUAL,
    DRESIDUAL_IN,
    Mean,  # pointer to the mean
    Rstd,  # pointer to the 1/std
    stride_x_row,  # how much to increase the pointer when moving by 1 row
    stride_y_row,
    stride_dy_row,
    stride_dx_row,
    stride_dres_row,
    stride_dres_in_row,
    M,  # number of rows in X
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    rows_per_program,
    IS_RMS_NORM: tl.constexpr,
    BLOCK_N: tl.constexpr,
    HAS_DRESIDUAL: tl.constexpr,
    STORE_DRESIDUAL: tl.constexpr,
    HAS_BIAS: tl.constexpr,
    RECOMPUTE_OUTPUT: tl.constexpr,
):
    # Map the program id to the elements of X, DX, and DY it should compute.
    row_block_id = tl.program_id(0)
    row_start = row_block_id * rows_per_program
    cols = tl.arange(0, BLOCK_N)
    mask = cols < N
    X += row_start * stride_x_row
    if HAS_DRESIDUAL:
        DRESIDUAL += row_start * stride_dres_row
    if STORE_DRESIDUAL:
        DRESIDUAL_IN += row_start * stride_dres_in_row
    DY += row_start * stride_dy_row
    DX += row_start * stride_dx_row
    if RECOMPUTE_OUTPUT:
        Y += row_start * stride_y_row
    w = tl.load(W + cols, mask=mask).to(tl.float32)
    if RECOMPUTE_OUTPUT and HAS_BIAS:
        b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32)
    dw = tl.zeros((BLOCK_N,), dtype=tl.float32)
    if HAS_BIAS:
        db = tl.zeros((BLOCK_N,), dtype=tl.float32)
    row_end = min((row_block_id + 1) * rows_per_program, M)
    for row in range(row_start, row_end):
        # Load data to SRAM
        x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
        dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
        if not IS_RMS_NORM:
            mean = tl.load(Mean + row)
        rstd = tl.load(Rstd + row)
        # Compute dx
        xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
        xhat = tl.where(mask, xhat, 0.0)
        if RECOMPUTE_OUTPUT:
            y = xhat * w + b if HAS_BIAS else xhat * w
            tl.store(Y + cols, y, mask=mask)
        wdy = w * dy
        dw += dy * xhat
        if HAS_BIAS:
            db += dy
        if not IS_RMS_NORM:
            c1 = tl.sum(xhat * wdy, axis=0) / N
            c2 = tl.sum(wdy, axis=0) / N
            dx = (wdy - (xhat * c1 + c2)) * rstd
        else:
            c1 = tl.sum(xhat * wdy, axis=0) / N
            dx = (wdy - xhat * c1) * rstd
        if HAS_DRESIDUAL:
            dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32)
            dx += dres
        # Write dx
        if STORE_DRESIDUAL:
            tl.store(DRESIDUAL_IN + cols, dx, mask=mask)
        tl.store(DX + cols, dx, mask=mask)

        X += stride_x_row
        if HAS_DRESIDUAL:
            DRESIDUAL += stride_dres_row
        if STORE_DRESIDUAL:
            DRESIDUAL_IN += stride_dres_in_row
        if RECOMPUTE_OUTPUT:
            Y += stride_y_row
        DY += stride_dy_row
        DX += stride_dx_row
    tl.store(DW + row_block_id * N + cols, dw, mask=mask)
    if HAS_BIAS:
        tl.store(DB + row_block_id * N + cols, db, mask=mask)


def _layer_norm_bwd(
    dy,
    x,
    weight,
    bias,
    eps,
    mean,
    rstd,
    dresidual=None,
    has_residual=False,
    is_rms_norm=False,
    x_dtype=None,
    recompute_output=False,
):
    M, N = x.shape
    assert x.stride(-1) == 1
    assert dy.stride(-1) == 1
    assert dy.shape == (M, N)
    if dresidual is not None:
        assert dresidual.stride(-1) == 1
        assert dresidual.shape == (M, N)
    assert weight.shape == (N,)
    assert weight.stride(-1) == 1
    if bias is not None:
        assert bias.stride(-1) == 1
        assert bias.shape == (N,)
    # allocate output
    dx = (
        torch.empty_like(x)
        if x_dtype is None
        else torch.empty(M, N, dtype=x_dtype, device=x.device)
    )
    dresidual_in = torch.empty_like(x) if has_residual and dx.dtype != x.dtype else None
    y = torch.empty(M, N, dtype=dy.dtype, device=dy.device) if recompute_output else None

    # Less than 64KB per feature: enqueue fused kernel
    MAX_FUSED_SIZE = 65536 // x.element_size()
    BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
    if N > BLOCK_N:
        raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
    sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
    _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device)
    _db = (
        torch.empty((sm_count, N), dtype=torch.float32, device=bias.device)
        if bias is not None
        else None
    )
    rows_per_program = math.ceil(M / sm_count)
    grid = (sm_count,)
    with torch.cuda.device(x.device.index):
        _layer_norm_bwd_kernel[grid](
            x,
            weight,
            bias,
            y,
            dy,
            dx,
            _dw,
            _db,
            dresidual,
            dresidual_in,
            mean,
            rstd,
            x.stride(0),
            0 if not recompute_output else y.stride(0),
            dy.stride(0),
            dx.stride(0),
            dresidual.stride(0) if dresidual is not None else 0,
            dresidual_in.stride(0) if dresidual_in is not None else 0,
            M,
            N,
            eps,
            rows_per_program,
            is_rms_norm,
            BLOCK_N,
            dresidual is not None,
            dresidual_in is not None,
            bias is not None,
        )
    dw = _dw.sum(0).to(weight.dtype)
    db = _db.sum(0).to(bias.dtype) if bias is not None else None
    # Don't need to compute dresidual_in separately in this case
    if has_residual and dx.dtype == x.dtype:
        dresidual_in = dx
    return (dx, dw, db, dresidual_in) if not recompute_output else (dx, dw, db, dresidual_in, y)


class LayerNormFn(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        x,
        weight,
        bias,
        residual=None,
        eps=1e-6,
        prenorm=False,
        residual_in_fp32=False,
        is_rms_norm=False,
    ):
        x_shape_og = x.shape
        # reshape input data into 2D tensor
        x = x.reshape(-1, x.shape[-1])
        if x.stride(-1) != 1:
            x = x.contiguous()
        if residual is not None:
            assert residual.shape == x_shape_og
            residual = residual.reshape(-1, residual.shape[-1])
            if residual.stride(-1) != 1:
                residual = residual.contiguous()
        weight = weight.contiguous()
        if bias is not None:
            bias = bias.contiguous()
        residual_dtype = (
            residual.dtype
            if residual is not None
            else (torch.float32 if residual_in_fp32 else None)
        )
        y, mean, rstd, residual_out = _layer_norm_fwd(
            x, weight, bias, eps, residual, residual_dtype=residual_dtype, is_rms_norm=is_rms_norm
        )
        ctx.save_for_backward(residual_out, weight, bias, mean, rstd)
        ctx.x_shape_og = x_shape_og
        ctx.eps = eps
        ctx.is_rms_norm = is_rms_norm
        ctx.has_residual = residual is not None
        ctx.prenorm = prenorm
        ctx.x_dtype = x.dtype
        y = y.reshape(x_shape_og)
        return y if not prenorm else (y, residual_out.reshape(x_shape_og))

    @staticmethod
    def backward(ctx, dy, *args):
        x, weight, bias, mean, rstd = ctx.saved_tensors
        dy = dy.reshape(-1, dy.shape[-1])
        if dy.stride(-1) != 1:
            dy = dy.contiguous()
        assert dy.shape == x.shape
        if ctx.prenorm:
            dresidual = args[0]
            dresidual = dresidual.reshape(-1, dresidual.shape[-1])
            if dresidual.stride(-1) != 1:
                dresidual = dresidual.contiguous()
            assert dresidual.shape == x.shape
        else:
            dresidual = None
        dx, dw, db, dresidual_in = _layer_norm_bwd(
            dy,
            x,
            weight,
            bias,
            ctx.eps,
            mean,
            rstd,
            dresidual,
            ctx.has_residual,
            ctx.is_rms_norm,
            x_dtype=ctx.x_dtype,
        )
        return (
            dx.reshape(ctx.x_shape_og),
            dw,
            db,
            dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
            None,
            None,
            None,
            None,
        )


def layer_norm_fn(
    x,
    weight,
    bias,
    residual=None,
    eps=1e-6,
    prenorm=False,
    residual_in_fp32=False,
    is_rms_norm=False,
):
    return LayerNormFn.apply(x, weight, bias, residual, eps, prenorm, residual_in_fp32, is_rms_norm)


def rms_norm_fn(x, weight, bias, residual=None, prenorm=False, residual_in_fp32=False, eps=1e-6):
    return LayerNormFn.apply(x, weight, bias, residual, eps, prenorm, residual_in_fp32, True)


class RMSNorm(torch.nn.Module):
    def __init__(self, hidden_size, eps=1e-5, device=None, dtype=None):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.register_parameter("bias", None)
        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.ones_(self.weight)

    def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False):
        return rms_norm_fn(
            x,
            self.weight,
            self.bias,
            residual=residual,
            eps=self.eps,
            prenorm=prenorm,
            residual_in_fp32=residual_in_fp32,
        )


class LayerNormLinearFn(torch.autograd.Function):
    @staticmethod
    @custom_fwd
    def forward(
        ctx,
        x,
        norm_weight,
        norm_bias,
        linear_weight,
        linear_bias,
        residual=None,
        eps=1e-6,
        prenorm=False,
        residual_in_fp32=False,
        is_rms_norm=False,
    ):
        x_shape_og = x.shape
        # reshape input data into 2D tensor
        x = x.reshape(-1, x.shape[-1])
        if x.stride(-1) != 1:
            x = x.contiguous()
        if residual is not None:
            assert residual.shape == x_shape_og
            residual = residual.reshape(-1, residual.shape[-1])
            if residual.stride(-1) != 1:
                residual = residual.contiguous()
        norm_weight = norm_weight.contiguous()
        if norm_bias is not None:
            norm_bias = norm_bias.contiguous()
        residual_dtype = (
            residual.dtype
            if residual is not None
            else (torch.float32 if residual_in_fp32 else None)
        )
        y, mean, rstd, residual_out = _layer_norm_fwd(
            x,
            norm_weight,
            norm_bias,
            eps,
            residual,
            out_dtype=None if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype(),
            residual_dtype=residual_dtype,
            is_rms_norm=is_rms_norm,
        )
        y = y.reshape(x_shape_og)
        dtype = torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else y.dtype
        linear_weight = linear_weight.to(dtype)
        linear_bias = linear_bias.to(dtype) if linear_bias is not None else None
        out = F.linear(y.to(linear_weight.dtype), linear_weight, linear_bias)
        # We don't store y, will be recomputed in the backward pass to save memory
        ctx.save_for_backward(residual_out, norm_weight, norm_bias, linear_weight, mean, rstd)
        ctx.x_shape_og = x_shape_og
        ctx.eps = eps
        ctx.is_rms_norm = is_rms_norm
        ctx.has_residual = residual is not None
        ctx.prenorm = prenorm
        ctx.x_dtype = x.dtype
        ctx.linear_bias_is_none = linear_bias is None
        return out if not prenorm else (out, residual_out.reshape(x_shape_og))

    @staticmethod
    @custom_bwd
    def backward(ctx, dout, *args):
        x, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors
        dout = dout.reshape(-1, dout.shape[-1])
        dy = F.linear(dout, linear_weight.t())
        dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0)
        if dy.stride(-1) != 1:
            dy = dy.contiguous()
        assert dy.shape == x.shape
        if ctx.prenorm:
            dresidual = args[0]
            dresidual = dresidual.reshape(-1, dresidual.shape[-1])
            if dresidual.stride(-1) != 1:
                dresidual = dresidual.contiguous()
            assert dresidual.shape == x.shape
        else:
            dresidual = None
        dx, dnorm_weight, dnorm_bias, dresidual_in, y = _layer_norm_bwd(
            dy,
            x,
            norm_weight,
            norm_bias,
            ctx.eps,
            mean,
            rstd,
            dresidual,
            ctx.has_residual,
            ctx.is_rms_norm,
            x_dtype=ctx.x_dtype,
            recompute_output=True,
        )
        dlinear_weight = torch.einsum("bo,bi->oi", dout, y)
        return (
            dx.reshape(ctx.x_shape_og),
            dnorm_weight,
            dnorm_bias,
            dlinear_weight,
            dlinear_bias,
            dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
            None,
            None,
            None,
            None,
        )


def layer_norm_linear_fn(
    x,
    norm_weight,
    norm_bias,
    linear_weight,
    linear_bias,
    residual=None,
    eps=1e-6,
    prenorm=False,
    residual_in_fp32=False,
    is_rms_norm=False,
):
    return LayerNormLinearFn.apply(
        x,
        norm_weight,
        norm_bias,
        linear_weight,
        linear_bias,
        residual,
        eps,
        prenorm,
        residual_in_fp32,
        is_rms_norm,
    )
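As a usage note, layer_norm_ref and rms_norm_ref above are plain-PyTorch references for the fused Triton kernels, which makes them convenient for correctness checks on small inputs. A minimal sketch with illustrative shapes (prenorm=True returns the updated residual stream alongside the normalized output, which is how the Mamba blocks consume it):

import torch
from mamba_ssm.ops.triton.layernorm import rms_norm_ref

x = torch.randn(3, 5, 64)
residual = torch.randn_like(x)
weight = torch.ones(64)
# prenorm=True returns (normalized_output, x + residual) so the caller can
# feed the updated residual stream into the next block.
out, new_residual = rms_norm_ref(x, weight, bias=None, residual=residual,
                                 eps=1e-5, prenorm=True, upcast=True)
print(out.shape, new_residual.shape)  # torch.Size([3, 5, 64]) twice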
mamba_ssm/ops/triton/selective_state_update.py
ADDED
@@ -0,0 +1,263 @@
# Copyright (c) 2024, Tri Dao, Albert Gu.

"""We want triton==2.1.0 or triton==2.2.0 or triton==2.3.0 for this
"""

import math
import torch
import torch.nn.functional as F

import triton
import triton.language as tl

from einops import rearrange, repeat


@triton.heuristics({"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None})
@triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None})
@triton.heuristics({"HAS_Z": lambda args: args["z_ptr"] is not None})
@triton.heuristics({"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])})
@triton.jit
def _selective_scan_update_kernel(
    # Pointers to matrices
    state_ptr, x_ptr, dt_ptr, dt_bias_ptr, A_ptr, B_ptr, C_ptr, D_ptr, z_ptr, out_ptr,
    # Matrix dimensions
    batch, nheads, dim, dstate, nheads_ngroups_ratio,
    # Strides
    stride_state_batch, stride_state_head, stride_state_dim, stride_state_dstate,
    stride_x_batch, stride_x_head, stride_x_dim,
    stride_dt_batch, stride_dt_head, stride_dt_dim,
    stride_dt_bias_head, stride_dt_bias_dim,
    stride_A_head, stride_A_dim, stride_A_dstate,
    stride_B_batch, stride_B_group, stride_B_dstate,
    stride_C_batch, stride_C_group, stride_C_dstate,
    stride_D_head, stride_D_dim,
    stride_z_batch, stride_z_head, stride_z_dim,
    stride_out_batch, stride_out_head, stride_out_dim,
    # Meta-parameters
    DT_SOFTPLUS: tl.constexpr,
    TIE_HDIM: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    HAS_DT_BIAS: tl.constexpr,
    HAS_D: tl.constexpr,
    HAS_Z: tl.constexpr,
    BLOCK_SIZE_DSTATE: tl.constexpr,
):
    pid_m = tl.program_id(axis=0)
    pid_b = tl.program_id(axis=1)
    pid_h = tl.program_id(axis=2)
    state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head
    x_ptr += pid_b * stride_x_batch + pid_h * stride_x_head
    dt_ptr += pid_b * stride_dt_batch + pid_h * stride_dt_head
    if HAS_DT_BIAS:
        dt_bias_ptr += pid_h * stride_dt_bias_head
    A_ptr += pid_h * stride_A_head
    B_ptr += pid_b * stride_B_batch + (pid_h // nheads_ngroups_ratio) * stride_B_group
    C_ptr += pid_b * stride_C_batch + (pid_h // nheads_ngroups_ratio) * stride_C_group
    if HAS_Z:
        z_ptr += pid_b * stride_z_batch + pid_h * stride_z_head
    out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = tl.arange(0, BLOCK_SIZE_DSTATE)
    state_ptrs = state_ptr + (offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate)
    x_ptrs = x_ptr + offs_m * stride_x_dim
    dt_ptrs = dt_ptr + offs_m * stride_dt_dim
    if HAS_DT_BIAS:
        dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim
    if HAS_D:
        D_ptr += pid_h * stride_D_head
    A_ptrs = A_ptr + (offs_m[:, None] * stride_A_dim + offs_n[None, :] * stride_A_dstate)
    B_ptrs = B_ptr + offs_n * stride_B_dstate
    C_ptrs = C_ptr + offs_n * stride_C_dstate
    if HAS_D:
        D_ptrs = D_ptr + offs_m * stride_D_dim
    if HAS_Z:
        z_ptrs = z_ptr + offs_m * stride_z_dim
    out_ptrs = out_ptr + offs_m * stride_out_dim

    state = tl.load(state_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0)
    x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
    if not TIE_HDIM:
        dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
        if HAS_DT_BIAS:
            dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
        if DT_SOFTPLUS:
            dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt)
        A = tl.load(A_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32)
        dA = tl.exp(A * dt[:, None])
    else:
        dt = tl.load(dt_ptr).to(tl.float32)
        if HAS_DT_BIAS:
            dt += tl.load(dt_bias_ptr).to(tl.float32)
        if DT_SOFTPLUS:
            dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt)
        A = tl.load(A_ptr).to(tl.float32)
        dA = tl.exp(A * dt)  # scalar, not a matrix

    B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
    C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
    if HAS_D:
        D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
    if HAS_Z:
        z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)

    if not TIE_HDIM:
        dB = B[None, :] * dt[:, None]
    else:
        dB = B * dt  # vector of size (dstate,)
    state = state * dA + dB * x[:, None]
    tl.store(state_ptrs, state, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate))
    out = tl.sum(state * C[None, :], axis=1)
    if HAS_D:
        out += x * D
    if HAS_Z:
        out *= z * tl.sigmoid(z)
    tl.store(out_ptrs, out, mask=offs_m < dim)


def selective_state_update(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False):
    """
    Argument:
        state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
        x: (batch, dim) or (batch, nheads, dim)
        dt: (batch, dim) or (batch, nheads, dim)
        A: (dim, dstate) or (nheads, dim, dstate)
        B: (batch, dstate) or (batch, ngroups, dstate)
        C: (batch, dstate) or (batch, ngroups, dstate)
        D: (dim,) or (nheads, dim)
        z: (batch, dim) or (batch, nheads, dim)
        dt_bias: (dim,) or (nheads, dim)
    Return:
        out: (batch, dim) or (batch, nheads, dim)
    """
    has_heads = state.dim() > 3
    if state.dim() == 3:
        state = state.unsqueeze(1)
    if x.dim() == 2:
        x = x.unsqueeze(1)
    if dt.dim() == 2:
        dt = dt.unsqueeze(1)
    if A.dim() == 2:
        A = A.unsqueeze(0)
    if B.dim() == 2:
        B = B.unsqueeze(1)
    if C.dim() == 2:
        C = C.unsqueeze(1)
    if D is not None and D.dim() == 1:
        D = D.unsqueeze(0)
    if z is not None and z.dim() == 2:
        z = z.unsqueeze(1)
    if dt_bias is not None and dt_bias.dim() == 1:
        dt_bias = dt_bias.unsqueeze(0)
    batch, nheads, dim, dstate = state.shape
    assert x.shape == (batch, nheads, dim)
    assert dt.shape == x.shape
    assert A.shape == (nheads, dim, dstate)
    ngroups = B.shape[1]
    assert nheads % ngroups == 0, "nheads must be divisible by ngroups"
    assert B.shape == (batch, ngroups, dstate)
    assert C.shape == B.shape
    if D is not None:
        assert D.shape == (nheads, dim)
    if z is not None:
        assert z.shape == x.shape
    if dt_bias is not None:
        assert dt_bias.shape == (nheads, dim)
    out = torch.empty_like(x)
    grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch, nheads)
    z_strides = ((z.stride(0), z.stride(1), z.stride(2)) if z is not None else (0, 0, 0))
    # We don't want autotune since it will overwrite the state
    # We instead tune by hand.
    BLOCK_SIZE_M, num_warps = ((32, 4) if dstate <= 16
                               else ((16, 4) if dstate <= 32 else
                                     ((8, 4) if dstate <= 64 else
                                      ((4, 4) if dstate <= 128 else
                                       ((4, 8))))))
    tie_hdim = A.stride(-1) == 0 and A.stride(-2) == 0 and dt.stride(-1) == 0 and dt_bias.stride(-1) == 0
    with torch.cuda.device(x.device.index):
        _selective_scan_update_kernel[grid](
            state, x, dt, dt_bias, A, B, C, D, z, out,
            batch, nheads, dim, dstate, nheads // ngroups,
            state.stride(0), state.stride(1), state.stride(2), state.stride(3),
            x.stride(0), x.stride(1), x.stride(2),
            dt.stride(0), dt.stride(1), dt.stride(2),
            *(dt_bias.stride(0), dt_bias.stride(1)) if dt_bias is not None else 0,
            A.stride(0), A.stride(1), A.stride(2),
            B.stride(0), B.stride(1), B.stride(2),
            C.stride(0), C.stride(1), C.stride(2),
            *(D.stride(0), D.stride(1)) if D is not None else 0,
            z_strides[0], z_strides[1], z_strides[2],
            out.stride(0), out.stride(1), out.stride(2),
            dt_softplus,
            tie_hdim,
            BLOCK_SIZE_M,
            num_warps=num_warps,
        )
    if not has_heads:
        out = out.squeeze(1)
    return out


def selective_state_update_ref(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False):
    """
    Argument:
        state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
        x: (batch, dim) or (batch, nheads, dim)
        dt: (batch, dim) or (batch, nheads, dim)
        A: (dim, dstate) or (nheads, dim, dstate)
        B: (batch, dstate) or (batch, ngroups, dstate)
        C: (batch, dstate) or (batch, ngroups, dstate)
        D: (dim,) or (nheads, dim)
        z: (batch, dim) or (batch, nheads, dim)
        dt_bias: (dim,) or (nheads, dim)
    Return:
        out: (batch, dim) or (batch, nheads, dim)
    """
    has_heads = state.dim() > 3
    if state.dim() == 3:
        state = state.unsqueeze(1)
    if x.dim() == 2:
        x = x.unsqueeze(1)
    if dt.dim() == 2:
        dt = dt.unsqueeze(1)
    if A.dim() == 2:
        A = A.unsqueeze(0)
    if B.dim() == 2:
        B = B.unsqueeze(1)
    if C.dim() == 2:
        C = C.unsqueeze(1)
    if D is not None and D.dim() == 1:
        D = D.unsqueeze(0)
    if z is not None and z.dim() == 2:
        z = z.unsqueeze(1)
    if dt_bias is not None and dt_bias.dim() == 1:
        dt_bias = dt_bias.unsqueeze(0)
    batch, nheads, dim, dstate = state.shape
    assert x.shape == (batch, nheads, dim)
    assert dt.shape == x.shape
    assert A.shape == (nheads, dim, dstate)
    ngroups = B.shape[1]
    assert nheads % ngroups == 0, "nheads must be divisible by ngroups"
    assert B.shape == (batch, ngroups, dstate)
    assert C.shape == B.shape
    if D is not None:
        assert D.shape == (nheads, dim)
    if z is not None:
        assert z.shape == x.shape
    if dt_bias is not None:
        assert dt_bias.shape == (nheads, dim)
    dt = dt + dt_bias if dt_bias is not None else dt
    dt = F.softplus(dt) if dt_softplus else dt
    dA = torch.exp(rearrange(dt, "b h d -> b h d 1") * A)  # (batch, nheads, dim, dstate)
    B = repeat(B, "b g n -> b (g h) n", h=nheads // ngroups)  # (batch, nheads, dstate)
    C = repeat(C, "b g n -> b (g h) n", h=nheads // ngroups)  # (batch, nheads, dstate)
    dB = rearrange(dt, "b h d -> b h d 1") * rearrange(B, "b h n -> b h 1 n")  # (batch, nheads, dim, dstate)
    state.copy_(state * dA + dB * rearrange(x, "b h d -> b h d 1"))  # (batch, nheads, dim, dstate)
    out = torch.einsum("bhdn,bhn->bhd", state.to(C.dtype), C)
    if D is not None:
        out += (x * D).to(out.dtype)
    out = (out if z is None else out * F.silu(z)).to(x.dtype)
    if not has_heads:
        out = out.squeeze(1)
    return out
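For orientation, one decoding step of this recurrence can be exercised with the pure-PyTorch fallback; on GPU, selective_state_update computes the same thing fused. The sketch below uses the single-head shapes from the docstring, and all concrete sizes are illustrative assumptions:

import torch
from mamba_ssm.ops.triton.selective_state_update import selective_state_update_ref

batch, dim, dstate = 2, 4, 8
state = torch.zeros(batch, dim, dstate)  # the recurrent state, updated in place
x = torch.randn(batch, dim)
dt = torch.randn(batch, dim)
dt_bias = torch.randn(dim)
A = -torch.rand(dim, dstate)
B = torch.randn(batch, dstate)
C = torch.randn(batch, dstate)
out = selective_state_update_ref(state, x, dt, A, B, C,
                                 dt_bias=dt_bias, dt_softplus=True)
print(out.shape)  # torch.Size([2, 4]); state now holds the updated recurrence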
mamba_ssm/utils/__init__.py
ADDED
File without changes
mamba_ssm/utils/generation.py
ADDED
@@ -0,0 +1,387 @@
# Copyright (c) 2023, Albert Gu, Tri Dao.
import gc
import time
from collections import namedtuple
from dataclasses import dataclass, field
from functools import partial
from typing import Callable, Optional, Sequence, Union

import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from torch import Tensor
from torch.profiler import ProfilerActivity, profile, record_function
from transformers.generation import GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput, TextStreamer


@dataclass
class InferenceParams:
    """Inference parameters that are passed to the main model in order
    to efficiently calculate and store the context during inference."""

    max_seqlen: int
    max_batch_size: int
    seqlen_offset: int = 0
    batch_size_offset: int = 0
    key_value_memory_dict: dict = field(default_factory=dict)
    lengths_per_sample: Optional[Tensor] = None

    def reset(self, max_seqlen, max_batch_size):
        self.max_seqlen = max_seqlen
        self.max_batch_size = max_batch_size
        self.seqlen_offset = 0
        if self.lengths_per_sample is not None:
            self.lengths_per_sample.zero_()


def modify_logits_for_min_p_filtering(logits, min_p):
    """Set the logits for non-min_p values to -inf. Done in-place."""
    if min_p <= 0.0 or min_p >= 1.0:
        return
    indices_to_remove = logits < min_p
    logits.masked_fill_(indices_to_remove, float("-Inf"))
# https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
# https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L231
def modify_logits_for_top_k_filtering(logits, top_k):
    """Set the logits for non-top-k values to -inf. Done in-place."""
    indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
    logits.masked_fill_(indices_to_remove, float("-Inf"))


# https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
# https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L170
def modify_logits_for_top_p_filtering(logits, top_p):
    """Set the logits for non-top-p values to -inf. Done in-place."""
    if top_p <= 0.0 or top_p >= 1.0:
        return
    # First sort and calculate cumulative sum of probabilities.
    sorted_logits, sorted_indices = torch.sort(logits, descending=False)
    cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
    # Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
    sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
    # scatter sorted tensors to original indexing
    indices_to_remove = sorted_indices_to_remove.scatter(
        1, sorted_indices, sorted_indices_to_remove
    )
    logits.masked_fill_(indices_to_remove, float("-inf"))


def modify_logit_for_repetition_penalty(logits, prev_output_tokens, repetition_penalty=1.0):
    """Apply repetition penalty. See https://arxiv.org/abs/1909.05858
    logits: (batch_size, vocab_size)
    prev_output_tokens: (batch_size, seq_len)
    """
    if repetition_penalty == 1.0:
        return logits
    score = torch.gather(logits, 1, prev_output_tokens)
    # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
    score = torch.where(score < 0, score * repetition_penalty, score / repetition_penalty)
    logits.scatter_(1, prev_output_tokens, score)
    return logits


def sample(logits, top_k=1, top_p=0.0, min_p=0.0, temperature=1.0):
    """Sample from top-k logits.
    Arguments:
        logits: Tensor of shape (batch_size, vocab_size)
    """
    if top_k == 1:  # Short-circuit for greedy decoding
        return logits.argmax(dim=-1)
    else:
        if top_p > 0.0:
            assert top_p <= 1.0, "top-p should be in (0, 1]."
        if top_k > 0:
            top_k = min(top_k, logits.size(-1))  # Safety check
            logits_top, indices = torch.topk(logits, top_k, dim=-1)
            if temperature != 1.0:
                logits_top /= temperature
            modify_logits_for_top_p_filtering(logits_top, top_p)
            return indices[
                torch.arange(indices.shape[0], device=indices.device),
                torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1),
            ]
        else:
            if min_p > 0.0:
                logits_top = logits.clone()
                max_prob = logits_top[..., 0].item()
                min_prob = max_prob * min_p
                modify_logits_for_min_p_filtering(logits_top, min_p)
                if temperature != 1.0:
                    logits_top /= temperature
                return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1)
            # Clone so that when we modify for top_p we don't change the original logits
            logits_top = logits / temperature if temperature != 1.0 else logits.clone()
            modify_logits_for_top_p_filtering(logits_top, top_p)
            return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(
                dim=-1
            )


@torch.inference_mode()
def decode(
    input_ids,
    model,
    max_length,
    top_k=1,
    top_p=0.0,
    min_p=0.0,
    temperature=1.0,
    repetition_penalty=1.0,
    eos_token_id=None,
    teacher_outputs=None,
    vocab_size=None,
    cg=False,
    enable_timing=False,
    streamer: Optional[TextStreamer] = None
):
    """Decoding, either greedy or with top-k or top-p sampling.
    If top-k = 0, don't limit the number of candidates (pure sampling).
    Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first,
    then top-p.
    We assume that all sequences in the same batch have the same length.

    Arguments:
        input_ids: (batch, seq_len)
        max_length: int
        teacher_outputs (optional): (batch, seq_len). If provided, instead of sampling from the
            logits, the next token is taken from the teacher_outputs. Useful for testing.
    Returns: GreedySearchDecoderOnlyOutput or SampleDecoderOnlyOutput, with the following fields:
        sequences: (batch, max_length)
        scores: tuples of (batch, vocab_size)
    """
    if streamer is not None:
        streamer.put(input_ids.cpu())

    batch_size, seqlen_og = input_ids.shape
    teacher_output_len = teacher_outputs.shape[1] if teacher_outputs is not None else 0
    if cg:
        if not hasattr(model, "_decoding_cache"):
            model._decoding_cache = None
        model._decoding_cache = update_graph_cache(
            model,
            model._decoding_cache,
            batch_size,
            seqlen_og,
            max_length,
        )
        inference_params = model._decoding_cache.inference_params
        inference_params.reset(max_length, batch_size)
    else:
        inference_params = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size)

    def get_logits(input_ids, inference_params):
        decoding = inference_params.seqlen_offset > 0
        if decoding:
            position_ids = torch.full(
                (batch_size, 1),
                inference_params.seqlen_offset,
                dtype=torch.long,
                device=input_ids.device,
            )
        else:
            position_ids = None
        if not cg or not decoding:
            logits = model(
                input_ids,
                position_ids=position_ids,
                inference_params=inference_params,
                num_last_tokens=1,
            ).logits.squeeze(dim=1)
        else:
            logits = model._decoding_cache.run(
                input_ids, position_ids, inference_params.seqlen_offset
            ).squeeze(dim=1)
        return logits[..., :vocab_size] if vocab_size is not None else logits

    def sample_tokens(logits, inference_params):
        if teacher_outputs is None or teacher_output_len <= inference_params.seqlen_offset:
            token = sample(logits, top_k=top_k, top_p=top_p, min_p=min_p, temperature=temperature)
        else:
            token = teacher_outputs[:, inference_params.seqlen_offset]
        # return rearrange(token, "b -> b 1")
        return token.unsqueeze(1)

    def should_stop(current_token, inference_params):
        if inference_params.seqlen_offset == 0:
            return False
        if eos_token_id is not None and (current_token == eos_token_id).all():
            return True
        if inference_params.seqlen_offset >= max_length - 1:
            return True
        return False

    start = torch.cuda.Event(enable_timing=enable_timing)
    end = torch.cuda.Event(enable_timing=enable_timing)

    if enable_timing:
        start.record()
    scores, sequences = [], [input_ids]
    sequences_cat = input_ids
    while not should_stop(sequences[-1], inference_params):
        scores.append(get_logits(sequences[-1], inference_params))
        inference_params.seqlen_offset += sequences[-1].shape[1]
        if repetition_penalty == 1.0:
            sampled_tokens = sample_tokens(scores[-1], inference_params)
        else:
            logits = modify_logit_for_repetition_penalty(
                scores[-1].clone(), sequences_cat, repetition_penalty
            )
            sampled_tokens = sample_tokens(logits, inference_params)
            sequences_cat = torch.cat([sequences_cat, sampled_tokens], dim=1)
        sequences.append(sampled_tokens)
        if streamer is not None:
            streamer.put(sampled_tokens.cpu())
    if streamer is not None:
        streamer.end()
    if enable_timing:
        end.record()
        torch.cuda.synchronize()
        print(f"Prompt processing + decoding time: {(start.elapsed_time(end)):.0f}ms")
    output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
    return output_cls(sequences=torch.cat(sequences, dim=1), scores=tuple(scores))


class GenerationMixin:
    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        raise NotImplementedError

    def generate(
        self,
        input_ids,
        max_length,
        top_k=1,
        top_p=0.0,
        min_p=0.0,
        temperature=1.0,
        return_dict_in_generate=False,
        output_scores=False,
        **kwargs,
    ):
        output = decode(
            input_ids, self, max_length, top_k=top_k, top_p=top_p, min_p=min_p, temperature=temperature, **kwargs
        )
        if not output_scores:
            output.scores = None
        return output if return_dict_in_generate else output.sequences


@dataclass
class DecodingCGCache:
    max_batch_size: int = 0
    max_seqlen: int = 0
    device = None
    dtype = None
    callables: dict = field(default_factory=dict)
    mempool = None
    inference_params: Optional[InferenceParams] = None
    run: Optional[Callable] = None


@torch.inference_mode()
def update_graph_cache(
    model,
    cache,
|
284 |
+
batch_size,
|
285 |
+
seqlen_og,
|
286 |
+
max_seqlen,
|
287 |
+
decoding_seqlens=(1,),
|
288 |
+
dtype=None,
|
289 |
+
n_warmups=2,
|
290 |
+
):
|
291 |
+
if cache is None:
|
292 |
+
cache = DecodingCGCache()
|
293 |
+
param_example = next(iter(model.parameters()))
|
294 |
+
device = param_example.device
|
295 |
+
if dtype is None:
|
296 |
+
dtype = param_example.dtype
|
297 |
+
if (
|
298 |
+
(device, dtype) != (cache.device, cache.dtype)
|
299 |
+
or batch_size > cache.max_batch_size
|
300 |
+
or max_seqlen > cache.max_seqlen
|
301 |
+
): # Invalidate the cache
|
302 |
+
cache.callables = {}
|
303 |
+
cache.mempool = None
|
304 |
+
cache.inference_params = None
|
305 |
+
gc.collect()
|
306 |
+
cache.device, cache.dtype = device, dtype
|
307 |
+
cache.max_batch_size, cache.max_seqlen = batch_size, max_seqlen
|
308 |
+
assert hasattr(model, "allocate_inference_cache"), "CUDA graph decoding requires that the model has a method allocate_inference_cache"
|
309 |
+
inf_cache = model.allocate_inference_cache(batch_size, max_seqlen, dtype)
|
310 |
+
lengths_per_sample = torch.full((batch_size,), seqlen_og, dtype=torch.int32, device=device)
|
311 |
+
cache.inference_params = InferenceParams(
|
312 |
+
max_seqlen=max_seqlen,
|
313 |
+
max_batch_size=batch_size,
|
314 |
+
seqlen_offset=seqlen_og,
|
315 |
+
key_value_memory_dict=inf_cache,
|
316 |
+
lengths_per_sample=lengths_per_sample,
|
317 |
+
)
|
318 |
+
cache.mempool = torch.cuda.graphs.graph_pool_handle()
|
319 |
+
for decoding_seqlen in decoding_seqlens:
|
320 |
+
if (batch_size, decoding_seqlen) not in cache.callables:
|
321 |
+
cache.callables[batch_size, decoding_seqlen] = capture_graph(
|
322 |
+
model,
|
323 |
+
cache.inference_params,
|
324 |
+
batch_size,
|
325 |
+
max_seqlen,
|
326 |
+
decoding_seqlen=decoding_seqlen,
|
327 |
+
mempool=cache.mempool,
|
328 |
+
n_warmups=n_warmups,
|
329 |
+
)
|
330 |
+
|
331 |
+
def dispatch(input_ids, position_ids, seqlen):
|
332 |
+
batch_size, decoding_seqlen = input_ids.shape[:2]
|
333 |
+
return cache.callables[batch_size, decoding_seqlen](input_ids, position_ids, seqlen)
|
334 |
+
|
335 |
+
cache.run = dispatch
|
336 |
+
cache.inference_params.seqlen_offset = 0 # Reset so it's not confusing
|
337 |
+
return cache
|
338 |
+
|
339 |
+
|
340 |
+
def capture_graph(
|
341 |
+
model, inference_params, batch_size, max_seqlen, decoding_seqlen=1, mempool=None, n_warmups=2
|
342 |
+
):
|
343 |
+
device = next(iter(model.parameters())).device
|
344 |
+
input_ids = torch.full((batch_size, decoding_seqlen), 0, dtype=torch.long, device=device)
|
345 |
+
position_ids = torch.full((batch_size, decoding_seqlen), 0, dtype=torch.long, device=device)
|
346 |
+
seqlen_offset_og = inference_params.seqlen_offset
|
347 |
+
inference_params.seqlen_offset = max_seqlen - decoding_seqlen
|
348 |
+
inference_params.lengths_per_sample[:] = inference_params.seqlen_offset
|
349 |
+
|
350 |
+
# Warmup before capture
|
351 |
+
s = torch.cuda.Stream()
|
352 |
+
s.wait_stream(torch.cuda.current_stream())
|
353 |
+
with torch.cuda.stream(s):
|
354 |
+
for _ in range(n_warmups):
|
355 |
+
logits = model(
|
356 |
+
input_ids,
|
357 |
+
position_ids=position_ids,
|
358 |
+
inference_params=inference_params,
|
359 |
+
num_last_tokens=decoding_seqlen,
|
360 |
+
).logits
|
361 |
+
s.synchronize()
|
362 |
+
# This might be needed for correctness if we run with NCCL_GRAPH_MIXING_SUPPORT=0,
|
363 |
+
# which requires that graph launch and non-captured launch to not overlap (I think,
|
364 |
+
# that's how I interpret the documentation). I'm not sure if this is required.
|
365 |
+
if torch.distributed.is_initialized():
|
366 |
+
torch.distributed.barrier()
|
367 |
+
torch.cuda.current_stream().wait_stream(s)
|
368 |
+
# Captures the graph
|
369 |
+
# To allow capture, automatically sets a side stream as the current stream in the context
|
370 |
+
graph = torch.cuda.CUDAGraph()
|
371 |
+
with torch.cuda.graph(graph, pool=mempool):
|
372 |
+
logits = model(
|
373 |
+
input_ids,
|
374 |
+
position_ids=position_ids,
|
375 |
+
inference_params=inference_params,
|
376 |
+
num_last_tokens=decoding_seqlen,
|
377 |
+
).logits
|
378 |
+
|
379 |
+
def run(new_input_ids, new_position_ids, seqlen):
|
380 |
+
inference_params.lengths_per_sample[:] = seqlen
|
381 |
+
input_ids.copy_(new_input_ids)
|
382 |
+
position_ids.copy_(new_position_ids)
|
383 |
+
graph.replay()
|
384 |
+
return logits.clone()
|
385 |
+
|
386 |
+
inference_params.seqlen_offset = seqlen_offset_og
|
387 |
+
return run
|
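A minimal sketch of how this decoding path is typically driven, assuming a CUDA device; the checkpoint, tokenizer, and prompt below are illustrative choices, not part of this Space:

import torch
from transformers import AutoTokenizer
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

# Mamba checkpoints use the GPT-NeoX tokenizer
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
model = MambaLMHeadModel.from_pretrained("state-spaces/mamba-130m", device="cuda", dtype=torch.float16)

input_ids = tokenizer("Speech enhancement is", return_tensors="pt").input_ids.to("cuda")
out = model.generate(
    input_ids,
    max_length=64,
    top_k=0, top_p=0.9, temperature=0.8,  # top_k=0 takes the nucleus-sampling branch in sample()
    cg=True,                              # reuse CUDA graphs built by update_graph_cache()
    return_dict_in_generate=True,
    output_scores=True,
)
print(tokenizer.decode(out.sequences[0]))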
mamba_ssm/utils/hf.py
ADDED
@@ -0,0 +1,23 @@
import json

import torch

from transformers.utils import WEIGHTS_NAME, CONFIG_NAME
from transformers.utils.hub import cached_file


def load_config_hf(model_name):
    resolved_archive_file = cached_file(model_name, CONFIG_NAME, _raise_exceptions_for_missing_entries=False)
    return json.load(open(resolved_archive_file))


def load_state_dict_hf(model_name, device=None, dtype=None):
    # If not fp32, then we don't want to load directly to the GPU
    mapped_device = "cpu" if dtype not in [torch.float32, None] else device
    resolved_archive_file = cached_file(model_name, WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False)
    state_dict = torch.load(resolved_archive_file, map_location=mapped_device)
    # Convert dtype before moving to GPU to save memory
    if dtype is not None:
        state_dict = {k: v.to(dtype=dtype) for k, v in state_dict.items()}
    state_dict = {k: v.to(device=device) for k, v in state_dict.items()}
    return state_dict
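A usage sketch for these helpers; the checkpoint name is an illustrative public Mamba model:

import torch
from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf

config = load_config_hf("state-spaces/mamba-130m")    # plain dict parsed from config.json
state_dict = load_state_dict_hf("state-spaces/mamba-130m", device="cpu", dtype=torch.float32)
print(config.get("d_model"), len(state_dict))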
models/codec_module.py
ADDED
@@ -0,0 +1,183 @@
# Reference: https://github.com/yxlu-0102/MP-SENet/blob/main/models/generator.py

import torch
import torch.nn as nn
from einops import rearrange
from .lsigmoid import LearnableSigmoid2D

def get_padding(kernel_size, dilation=1):
    """
    Calculate the padding size for a convolutional layer.

    Args:
    - kernel_size (int): Size of the convolutional kernel.
    - dilation (int, optional): Dilation rate of the convolution. Defaults to 1.

    Returns:
    - int: Calculated padding size.
    """
    return int((kernel_size * dilation - dilation) / 2)

def get_padding_2d(kernel_size, dilation=(1, 1)):
    """
    Calculate the padding size for a 2D convolutional layer.

    Args:
    - kernel_size (tuple): Size of the convolutional kernel (height, width).
    - dilation (tuple, optional): Dilation rate of the convolution (height, width). Defaults to (1, 1).

    Returns:
    - tuple: Calculated padding size (height, width).
    """
    return (int((kernel_size[0] * dilation[0] - dilation[0]) / 2),
            int((kernel_size[1] * dilation[1] - dilation[1]) / 2))

class DenseBlock(nn.Module):
    """
    DenseBlock module consisting of multiple convolutional layers with dilation.
    """
    def __init__(self, cfg, kernel_size=(3, 3), depth=4):
        super(DenseBlock, self).__init__()
        self.cfg = cfg
        self.depth = depth
        self.dense_block = nn.ModuleList()
        self.hid_feature = cfg['model_cfg']['hid_feature']

        for i in range(depth):
            dil = 2 ** i
            dense_conv = nn.Sequential(
                nn.Conv2d(self.hid_feature * (i + 1), self.hid_feature, kernel_size,
                          dilation=(dil, 1), padding=get_padding_2d(kernel_size, (dil, 1))),
                nn.InstanceNorm2d(self.hid_feature, affine=True),
                nn.PReLU(self.hid_feature)
            )
            self.dense_block.append(dense_conv)

    def forward(self, x):
        """
        Forward pass for the DenseBlock module.

        Args:
        - x (torch.Tensor): Input tensor.

        Returns:
        - torch.Tensor: Output tensor after processing through the dense block.
        """
        skip = x
        for i in range(self.depth):
            x = self.dense_block[i](skip)
            skip = torch.cat([x, skip], dim=1)
        return x

class DenseEncoder(nn.Module):
    """
    DenseEncoder module consisting of initial convolution, dense block, and a final convolution.
    """
    def __init__(self, cfg):
        super(DenseEncoder, self).__init__()
        self.cfg = cfg
        self.input_channel = cfg['model_cfg']['input_channel']
        self.hid_feature = cfg['model_cfg']['hid_feature']

        self.dense_conv_1 = nn.Sequential(
            nn.Conv2d(self.input_channel, self.hid_feature, (1, 1)),
            nn.InstanceNorm2d(self.hid_feature, affine=True),
            nn.PReLU(self.hid_feature)
        )

        self.dense_block = DenseBlock(cfg, depth=4)

        self.dense_conv_2 = nn.Sequential(
            nn.Conv2d(self.hid_feature, self.hid_feature, (1, 3), stride=(1, 2)),
            nn.InstanceNorm2d(self.hid_feature, affine=True),
            nn.PReLU(self.hid_feature)
        )

    def forward(self, x):
        """
        Forward pass for the DenseEncoder module.

        Args:
        - x (torch.Tensor): Input tensor.

        Returns:
        - torch.Tensor: Encoded tensor.
        """
        x = self.dense_conv_1(x)  # [batch, hid_feature, time, freq]
        x = self.dense_block(x)   # [batch, hid_feature, time, freq]
        x = self.dense_conv_2(x)  # [batch, hid_feature, time, freq//2]
        return x

class MagDecoder(nn.Module):
    """
    MagDecoder module for decoding magnitude information.
    """
    def __init__(self, cfg):
        super(MagDecoder, self).__init__()
        self.dense_block = DenseBlock(cfg, depth=4)
        self.hid_feature = cfg['model_cfg']['hid_feature']
        self.output_channel = cfg['model_cfg']['output_channel']
        self.n_fft = cfg['stft_cfg']['n_fft']
        self.beta = cfg['model_cfg']['beta']

        self.mask_conv = nn.Sequential(
            nn.ConvTranspose2d(self.hid_feature, self.hid_feature, (1, 3), stride=(1, 2)),
            nn.Conv2d(self.hid_feature, self.output_channel, (1, 1)),
            nn.InstanceNorm2d(self.output_channel, affine=True),
            nn.PReLU(self.output_channel),
            nn.Conv2d(self.output_channel, self.output_channel, (1, 1))
        )
        self.lsigmoid = LearnableSigmoid2D(self.n_fft // 2 + 1, beta=self.beta)

    def forward(self, x):
        """
        Forward pass for the MagDecoder module.

        Args:
        - x (torch.Tensor): Input tensor.

        Returns:
        - torch.Tensor: Decoded tensor with magnitude information.
        """
        x = self.dense_block(x)
        x = self.mask_conv(x)
        x = rearrange(x, 'b c t f -> b f t c').squeeze(-1)
        x = self.lsigmoid(x)
        x = rearrange(x, 'b f t -> b t f').unsqueeze(1)
        return x

class PhaseDecoder(nn.Module):
    """
    PhaseDecoder module for decoding phase information.
    """
    def __init__(self, cfg):
        super(PhaseDecoder, self).__init__()
        self.dense_block = DenseBlock(cfg, depth=4)
        self.hid_feature = cfg['model_cfg']['hid_feature']
        self.output_channel = cfg['model_cfg']['output_channel']

        self.phase_conv = nn.Sequential(
            nn.ConvTranspose2d(self.hid_feature, self.hid_feature, (1, 3), stride=(1, 2)),
            nn.InstanceNorm2d(self.hid_feature, affine=True),
            nn.PReLU(self.hid_feature)
        )

        self.phase_conv_r = nn.Conv2d(self.hid_feature, self.output_channel, (1, 1))
        self.phase_conv_i = nn.Conv2d(self.hid_feature, self.output_channel, (1, 1))

    def forward(self, x):
        """
        Forward pass for the PhaseDecoder module.

        Args:
        - x (torch.Tensor): Input tensor.

        Returns:
        - torch.Tensor: Decoded tensor with phase information.
        """
        x = self.dense_block(x)
        x = self.phase_conv(x)
        x_r = self.phase_conv_r(x)
        x_i = self.phase_conv_i(x)
        x = torch.atan2(x_i, x_r)
        return x
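A quick shape check for this encoder/decoder stack, with hid_feature, channels, and n_fft taken from recipes/SEMamba_advanced.yaml (n_fft=400 gives 201 frequency bins); runs on CPU:

import torch
from models.codec_module import DenseEncoder, MagDecoder, PhaseDecoder

cfg = {
    'model_cfg': {'hid_feature': 64, 'input_channel': 2, 'output_channel': 1, 'beta': 2.0},
    'stft_cfg': {'n_fft': 400},
}
x = torch.randn(1, 2, 50, 201)   # [batch, mag+phase, time, freq]
enc = DenseEncoder(cfg)(x)       # -> [1, 64, 50, 100], frequency axis halved
mag = MagDecoder(cfg)(enc)       # -> [1, 1, 50, 201]
pha = PhaseDecoder(cfg)(enc)     # -> [1, 1, 50, 201]
print(enc.shape, mag.shape, pha.shape)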
models/discriminator.py
ADDED
@@ -0,0 +1,56 @@
# References: https://github.com/yxlu-0102/MP-SENet/blob/main/models/discriminator.py

import torch
import torch.nn as nn
import numpy as np
from pesq import pesq
from joblib import Parallel, delayed
from models.lsigmoid import LearnableSigmoid1D

def pesq_loss(clean, noisy, sr=16000):
    try:
        pesq_score = pesq(sr, clean, noisy, 'wb')
    except Exception:
        # error can happen due to silent period
        pesq_score = -1
    return pesq_score


def batch_pesq(clean, noisy, cfg):
    num_worker = cfg['env_setting']['num_workers']
    pesq_score = Parallel(n_jobs=num_worker)(delayed(pesq_loss)(c, n) for c, n in zip(clean, noisy))
    pesq_score = np.array(pesq_score)
    if -1 in pesq_score:
        return None
    pesq_score = (pesq_score - 1) / 3.5
    return torch.FloatTensor(pesq_score)


class MetricDiscriminator(nn.Module):
    def __init__(self, dim=16, in_channel=2):
        super(MetricDiscriminator, self).__init__()
        self.layers = nn.Sequential(
            nn.utils.spectral_norm(nn.Conv2d(in_channel, dim, (4, 4), (2, 2), (1, 1), bias=False)),
            nn.InstanceNorm2d(dim, affine=True),
            nn.PReLU(dim),
            nn.utils.spectral_norm(nn.Conv2d(dim, dim * 2, (4, 4), (2, 2), (1, 1), bias=False)),
            nn.InstanceNorm2d(dim * 2, affine=True),
            nn.PReLU(dim * 2),
            nn.utils.spectral_norm(nn.Conv2d(dim * 2, dim * 4, (4, 4), (2, 2), (1, 1), bias=False)),
            nn.InstanceNorm2d(dim * 4, affine=True),
            nn.PReLU(dim * 4),
            nn.utils.spectral_norm(nn.Conv2d(dim * 4, dim * 8, (4, 4), (2, 2), (1, 1), bias=False)),
            nn.InstanceNorm2d(dim * 8, affine=True),
            nn.PReLU(dim * 8),
            nn.AdaptiveMaxPool2d(1),
            nn.Flatten(),
            nn.utils.spectral_norm(nn.Linear(dim * 8, dim * 4)),
            nn.Dropout(0.3),
            nn.PReLU(dim * 4),
            nn.utils.spectral_norm(nn.Linear(dim * 4, 1)),
            LearnableSigmoid1D(1)
        )

    def forward(self, x, y):
        xy = torch.stack((x, y), dim=1)
        return self.layers(xy)
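A shape sketch for the discriminator, under the assumption that x and y are compressed magnitude spectrograms of shape [batch, freq, time]; the stacked pair becomes the two input channels:

import torch
from models.discriminator import MetricDiscriminator

disc = MetricDiscriminator()
clean_mag = torch.randn(4, 201, 321)
enhanced_mag = torch.randn(4, 201, 321)
score = disc(clean_mag, enhanced_mag)   # predicted normalized quality score, shape [4, 1]
print(score.shape)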
models/generator.py
ADDED
@@ -0,0 +1,72 @@
import torch
import torch.nn as nn
from einops import rearrange
from .mamba_block import TFMambaBlock
from .codec_module import DenseEncoder, MagDecoder, PhaseDecoder

class SEMamba(nn.Module):
    """
    SEMamba model for speech enhancement using Mamba blocks.

    This model uses a dense encoder, multiple Mamba blocks, and separate magnitude
    and phase decoders to process noisy magnitude and phase inputs.
    """
    def __init__(self, cfg):
        """
        Initialize the SEMamba model.

        Args:
        - cfg: Configuration object containing model parameters.
        """
        super(SEMamba, self).__init__()
        self.cfg = cfg
        self.num_tscblocks = cfg['model_cfg']['num_tfmamba'] if cfg['model_cfg']['num_tfmamba'] is not None else 4  # default tfmamba: 4

        # Initialize dense encoder
        self.dense_encoder = DenseEncoder(cfg)

        # Initialize Mamba blocks
        self.TSMamba = nn.ModuleList([TFMambaBlock(cfg) for _ in range(self.num_tscblocks)])

        # Initialize decoders
        self.mask_decoder = MagDecoder(cfg)
        self.phase_decoder = PhaseDecoder(cfg)

    def forward(self, noisy_mag, noisy_pha):
        """
        Forward pass for the SEMamba model.

        Args:
        - noisy_mag (torch.Tensor): Noisy magnitude input tensor [B, F, T].
        - noisy_pha (torch.Tensor): Noisy phase input tensor [B, F, T].

        Returns:
        - denoised_mag (torch.Tensor): Denoised magnitude tensor [B, F, T].
        - denoised_pha (torch.Tensor): Denoised phase tensor [B, F, T].
        - denoised_com (torch.Tensor): Denoised complex tensor [B, F, T, 2].
        """
        # Reshape inputs
        noisy_mag = rearrange(noisy_mag, 'b f t -> b t f').unsqueeze(1)  # [B, 1, T, F]
        noisy_pha = rearrange(noisy_pha, 'b f t -> b t f').unsqueeze(1)  # [B, 1, T, F]

        # Concatenate magnitude and phase inputs
        x = torch.cat((noisy_mag, noisy_pha), dim=1)  # [B, 2, T, F]

        # Encode input
        x = self.dense_encoder(x)

        # Apply Mamba blocks
        for block in self.TSMamba:
            x = block(x)

        # Decode magnitude and phase
        denoised_mag = rearrange(self.mask_decoder(x) * noisy_mag, 'b c t f -> b f t c').squeeze(-1)
        denoised_pha = rearrange(self.phase_decoder(x), 'b c t f -> b f t c').squeeze(-1)

        # Combine denoised magnitude and phase into a complex representation
        denoised_com = torch.stack(
            (denoised_mag * torch.cos(denoised_pha), denoised_mag * torch.sin(denoised_pha)),
            dim=-1
        )

        return denoised_mag, denoised_pha, denoised_com
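An end-to-end shape sketch for SEMamba; this needs a CUDA device because the Mamba scan kernels are GPU-only, and it reads the recipe shipped with this Space:

import torch
import yaml
from models.generator import SEMamba

with open("recipes/SEMamba_advanced.yaml") as f:
    cfg = yaml.safe_load(f)

model = SEMamba(cfg).cuda()
noisy_mag = torch.randn(1, 201, 100, device="cuda")   # [B, F, T]
noisy_pha = torch.randn(1, 201, 100, device="cuda")
mag, pha, com = model(noisy_mag, noisy_pha)
print(mag.shape, pha.shape, com.shape)   # [1, 201, 100], [1, 201, 100], [1, 201, 100, 2]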
models/loss.py
ADDED
@@ -0,0 +1,145 @@
# Reference: https://github.com/yxlu-0102/MP-SENet/blob/main/models/generator.py

import torch
import torch.nn as nn
import numpy as np
from pesq import pesq
from joblib import Parallel, delayed

def phase_losses(phase_r, phase_g, cfg):
    """
    Calculate phase losses including in-phase loss, gradient delay loss,
    and integrated absolute frequency loss between reference and generated phases.

    Args:
        phase_r (torch.Tensor): Reference phase tensor of shape (batch, freq, time).
        phase_g (torch.Tensor): Generated phase tensor of shape (batch, freq, time).
        cfg (dict): Configuration dictionary containing parameters like n_fft.

    Returns:
        tuple: Tuple containing in-phase loss, gradient delay loss, and integrated absolute frequency loss.
    """
    dim_freq = cfg['stft_cfg']['n_fft'] // 2 + 1  # Calculate frequency dimension
    dim_time = phase_r.size(-1)                   # Calculate time dimension

    # Construct gradient delay matrix
    gd_matrix = (torch.triu(torch.ones(dim_freq, dim_freq), diagonal=1) -
                 torch.triu(torch.ones(dim_freq, dim_freq), diagonal=2) -
                 torch.eye(dim_freq)).to(phase_g.device)

    # Apply gradient delay matrix to reference and generated phases
    gd_r = torch.matmul(phase_r.permute(0, 2, 1), gd_matrix)
    gd_g = torch.matmul(phase_g.permute(0, 2, 1), gd_matrix)

    # Construct integrated absolute frequency matrix
    iaf_matrix = (torch.triu(torch.ones(dim_time, dim_time), diagonal=1) -
                  torch.triu(torch.ones(dim_time, dim_time), diagonal=2) -
                  torch.eye(dim_time)).to(phase_g.device)

    # Apply integrated absolute frequency matrix to reference and generated phases
    iaf_r = torch.matmul(phase_r, iaf_matrix)
    iaf_g = torch.matmul(phase_g, iaf_matrix)

    # Calculate losses
    ip_loss = torch.mean(anti_wrapping_function(phase_r - phase_g))
    gd_loss = torch.mean(anti_wrapping_function(gd_r - gd_g))
    iaf_loss = torch.mean(anti_wrapping_function(iaf_r - iaf_g))

    return ip_loss, gd_loss, iaf_loss

def anti_wrapping_function(x):
    """
    Anti-wrapping function to adjust phase values within the range of -pi to pi.

    Args:
        x (torch.Tensor): Input tensor representing phase differences.

    Returns:
        torch.Tensor: Adjusted tensor with phase values wrapped within -pi to pi.
    """
    return torch.abs(x - torch.round(x / (2 * np.pi)) * 2 * np.pi)

def compute_stft(y: torch.Tensor, n_fft: int, hop_size: int, win_size: int, center: bool, compress_factor: float = 1.0) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Compute the Short-Time Fourier Transform (STFT) and return magnitude, phase, and complex components.

    Args:
        y (torch.Tensor): Input signal tensor.
        n_fft (int): Number of FFT points.
        hop_size (int): Hop size for STFT.
        win_size (int): Window size for STFT.
        center (bool): Whether to pad the input on both sides.
        compress_factor (float, optional): Compression factor for magnitude. Defaults to 1.0.

    Returns:
        tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Magnitude, phase, and complex components.
    """
    eps = torch.finfo(y.dtype).eps
    hann_window = torch.hann_window(win_size).to(y.device)

    stft_spec = torch.stft(
        y,
        n_fft=n_fft,
        hop_length=hop_size,
        win_length=win_size,
        window=hann_window,
        center=center,
        pad_mode='reflect',
        normalized=False,
        return_complex=True
    )

    real_part = stft_spec.real
    imag_part = stft_spec.imag

    mag = torch.sqrt(real_part.pow(2) + imag_part.pow(2) + eps)
    pha = torch.atan2(imag_part + eps, real_part + eps)

    mag = torch.pow(mag, compress_factor)
    com = torch.stack((mag * torch.cos(pha), mag * torch.sin(pha)), dim=-1)

    return mag, pha, com

def pesq_score(utts_r, utts_g, cfg):
    """
    Calculate PESQ (Perceptual Evaluation of Speech Quality) score for pairs of reference and generated utterances.

    Args:
        utts_r (list of torch.Tensor): List of reference utterances.
        utts_g (list of torch.Tensor): List of generated utterances.
        cfg (dict): Configuration dictionary containing parameters like sampling_rate.

    Returns:
        float: Mean PESQ score across all pairs of utterances.
    """
    def eval_pesq(clean_utt, esti_utt, sr):
        """
        Evaluate PESQ score for a single pair of clean and estimated utterances.

        Args:
            clean_utt (np.ndarray): Clean reference utterance.
            esti_utt (np.ndarray): Estimated generated utterance.
            sr (int): Sampling rate.

        Returns:
            float: PESQ score or -1 in case of an error.
        """
        try:
            pesq_score = pesq(sr, clean_utt, esti_utt)
        except Exception as e:
            # Error can happen due to silent period or other issues
            print(f"Error computing PESQ score: {e}")
            pesq_score = -1
        return pesq_score

    # Parallel processing of PESQ score computation
    pesq_scores = Parallel(n_jobs=30)(delayed(eval_pesq)(
        utts_r[i].squeeze().cpu().numpy(),
        utts_g[i].squeeze().cpu().numpy(),
        cfg['stft_cfg']['sampling_rate']
    ) for i in range(len(utts_r)))

    # Calculate mean PESQ score
    pesq_score = np.mean(pesq_scores)
    return pesq_score
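A quick numeric check of phase_losses with random phases in (-pi, pi); only stft_cfg.n_fft is read from cfg, and the frequency axis must match n_fft // 2 + 1:

import torch
from models.loss import phase_losses

cfg = {'stft_cfg': {'n_fft': 400}}
phase_r = (torch.rand(2, 201, 100) * 2 - 1) * torch.pi   # [batch, freq, time]
phase_g = (torch.rand(2, 201, 100) * 2 - 1) * torch.pi
ip, gd, iaf = phase_losses(phase_r, phase_g, cfg)
print(ip.item(), gd.item(), iaf.item())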
models/lsigmoid.py
ADDED
@@ -0,0 +1,66 @@
# Reference: https://github.com/yxlu-0102/MP-SENet/blob/main/utils.py

import torch
import torch.nn as nn

class LearnableSigmoid1D(nn.Module):
    """
    Learnable Sigmoid Activation Function for 1D inputs.

    This module applies a learnable slope parameter to the sigmoid activation function.
    """
    def __init__(self, in_features, beta=1):
        """
        Initialize the LearnableSigmoid1D module.

        Args:
        - in_features (int): Number of input features.
        - beta (float, optional): Scaling factor for the sigmoid function. Defaults to 1.
        """
        super(LearnableSigmoid1D, self).__init__()
        self.beta = beta
        self.slope = nn.Parameter(torch.ones(in_features))
        self.slope.requires_grad = True

    def forward(self, x):
        """
        Forward pass for the LearnableSigmoid1D module.

        Args:
        - x (torch.Tensor): Input tensor.

        Returns:
        - torch.Tensor: Output tensor after applying the learnable sigmoid activation.
        """
        return self.beta * torch.sigmoid(self.slope * x)

class LearnableSigmoid2D(nn.Module):
    """
    Learnable Sigmoid Activation Function for 2D inputs.

    This module applies a learnable slope parameter to the sigmoid activation function for 2D inputs.
    """
    def __init__(self, in_features, beta=1):
        """
        Initialize the LearnableSigmoid2D module.

        Args:
        - in_features (int): Number of input features.
        - beta (float, optional): Scaling factor for the sigmoid function. Defaults to 1.
        """
        super(LearnableSigmoid2D, self).__init__()
        self.beta = beta
        self.slope = nn.Parameter(torch.ones(in_features, 1))
        self.slope.requires_grad = True

    def forward(self, x):
        """
        Forward pass for the LearnableSigmoid2D module.

        Args:
        - x (torch.Tensor): Input tensor.

        Returns:
        - torch.Tensor: Output tensor after applying the learnable sigmoid activation.
        """
        return self.beta * torch.sigmoid(self.slope * x)
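Since the slope parameter is initialized to ones, the module initially behaves exactly like beta * sigmoid(x), which this one-liner confirms:

import torch
from models.lsigmoid import LearnableSigmoid1D

act = LearnableSigmoid1D(in_features=4, beta=2.0)
print(act(torch.zeros(1, 4)))   # -> 2.0 * sigmoid(0) = 1.0 everywhere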
models/mamba_block.py
ADDED
@@ -0,0 +1,110 @@
# Reference: https://github.com/state-spaces/mamba/blob/9127d1f47f367f5c9cc49c73ad73557089d02cb8/mamba_ssm/models/mixer_seq_simple.py

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
from torch.nn.parameter import Parameter
from functools import partial
from einops import rearrange

from mamba_ssm.modules.mamba_simple import Mamba, Block
from mamba_ssm.models.mixer_seq_simple import _init_weights
from mamba_ssm.ops.triton.layernorm import RMSNorm

# github: https://github.com/state-spaces/mamba/blob/9127d1f47f367f5c9cc49c73ad73557089d02cb8/mamba_ssm/models/mixer_seq_simple.py
def create_block(
    d_model, cfg, layer_idx=0, rms_norm=True, fused_add_norm=False, residual_in_fp32=False,
):
    d_state = cfg['model_cfg']['d_state']            # 16
    d_conv = cfg['model_cfg']['d_conv']              # 4
    expand = cfg['model_cfg']['expand']              # 4
    norm_epsilon = cfg['model_cfg']['norm_epsilon']  # 0.00001

    mixer_cls = partial(Mamba, layer_idx=layer_idx, d_state=d_state, d_conv=d_conv, expand=expand)
    norm_cls = partial(
        nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon
    )
    block = Block(
        d_model,
        mixer_cls,
        norm_cls=norm_cls,
        fused_add_norm=fused_add_norm,
        residual_in_fp32=residual_in_fp32,
    )
    block.layer_idx = layer_idx
    return block

class MambaBlock(nn.Module):
    def __init__(self, in_channels, cfg):
        super(MambaBlock, self).__init__()
        n_layer = 1
        self.forward_blocks = nn.ModuleList(create_block(in_channels, cfg) for i in range(n_layer))
        self.backward_blocks = nn.ModuleList(create_block(in_channels, cfg) for i in range(n_layer))

        self.apply(
            partial(
                _init_weights,
                n_layer=n_layer,
            )
        )

    def forward(self, x):
        x_forward, x_backward = x.clone(), torch.flip(x, [1])
        resi_forward, resi_backward = None, None

        # Forward
        for layer in self.forward_blocks:
            x_forward, resi_forward = layer(x_forward, resi_forward)
        y_forward = (x_forward + resi_forward) if resi_forward is not None else x_forward

        # Backward
        for layer in self.backward_blocks:
            x_backward, resi_backward = layer(x_backward, resi_backward)
        y_backward = torch.flip((x_backward + resi_backward), [1]) if resi_backward is not None else torch.flip(x_backward, [1])

        return torch.cat([y_forward, y_backward], -1)

class TFMambaBlock(nn.Module):
    """
    Temporal-Frequency Mamba block for sequence modeling.

    Attributes:
        cfg (Config): Configuration for the block.
        time_mamba (MambaBlock): Mamba block for temporal dimension.
        freq_mamba (MambaBlock): Mamba block for frequency dimension.
        tlinear (ConvTranspose1d): ConvTranspose1d layer for temporal dimension.
        flinear (ConvTranspose1d): ConvTranspose1d layer for frequency dimension.
    """
    def __init__(self, cfg):
        super(TFMambaBlock, self).__init__()
        self.cfg = cfg
        self.hid_feature = cfg['model_cfg']['hid_feature']

        # Initialize Mamba blocks
        self.time_mamba = MambaBlock(in_channels=self.hid_feature, cfg=cfg)
        self.freq_mamba = MambaBlock(in_channels=self.hid_feature, cfg=cfg)

        # Initialize ConvTranspose1d layers
        self.tlinear = nn.ConvTranspose1d(self.hid_feature * 2, self.hid_feature, 1, stride=1)
        self.flinear = nn.ConvTranspose1d(self.hid_feature * 2, self.hid_feature, 1, stride=1)

    def forward(self, x):
        """
        Forward pass of the TFMamba block.

        Parameters:
            x (Tensor): Input tensor with shape (batch, channels, time, freq).

        Returns:
            Tensor: Output tensor after applying temporal and frequency Mamba blocks.
        """
        b, c, t, f = x.size()

        x = x.permute(0, 3, 2, 1).contiguous().view(b * f, t, c)
        x = self.tlinear(self.time_mamba(x).permute(0, 2, 1)).permute(0, 2, 1) + x
        x = x.view(b, f, t, c).permute(0, 2, 1, 3).contiguous().view(b * t, f, c)
        x = self.flinear(self.freq_mamba(x).permute(0, 2, 1)).permute(0, 2, 1) + x
        x = x.view(b, t, f, c).permute(0, 3, 1, 2)
        return x
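A shape sketch for a single TFMambaBlock (CUDA-only, since the Mamba kernels require a GPU); the model_cfg values mirror the recipe defaults:

import torch
from models.mamba_block import TFMambaBlock

cfg = {'model_cfg': {'hid_feature': 64, 'd_state': 16, 'd_conv': 4, 'expand': 4,
                     'norm_epsilon': 1e-5}}
block = TFMambaBlock(cfg).cuda()
x = torch.randn(1, 64, 50, 100, device="cuda")   # [batch, channels, time, freq]
print(block(x).shape)                            # shape is preserved: [1, 64, 50, 100]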
models/pcs400.py
ADDED
@@ -0,0 +1,53 @@
import os
import torch
import torchaudio
import numpy as np
import argparse
import librosa
import scipy.signal

# PCS400 parameters
PCS400 = np.ones(201)
PCS400[0:3] = 1
PCS400[3:5] = 1.070175439
PCS400[5:8] = 1.182456140
PCS400[8:10] = 1.287719298
PCS400[10:110] = 1.4  # Pre Set
PCS400[110:130] = 1.322807018
PCS400[130:160] = 1.238596491
PCS400[160:190] = 1.161403509
PCS400[190:201] = 1.077192982

maxv = np.iinfo(np.int16).max

def Sp_and_phase(signal):
    signal_length = signal.shape[0]
    n_fft = 400
    hop_length = 100
    y_pad = librosa.util.fix_length(signal, size=signal_length + n_fft // 2)

    F = librosa.stft(y_pad, n_fft=400, hop_length=100, win_length=400, window=scipy.signal.windows.hamming(400))
    Lp = PCS400 * np.transpose(np.log1p(np.abs(F)), (1, 0))
    phase = np.angle(F)

    NLp = np.transpose(Lp, (1, 0))

    return NLp, phase, signal_length


def SP_to_wav(mag, phase, signal_length):
    mag = np.expm1(mag)
    Rec = np.multiply(mag, np.exp(1j * phase))
    result = librosa.istft(Rec,
                           hop_length=100,
                           win_length=400,
                           window=scipy.signal.windows.hamming(400),
                           length=signal_length)
    return result

def cal_pcs(signal_wav):
    noisy_LP, Nphase, signal_length = Sp_and_phase(signal_wav.squeeze())
    enhanced_wav = SP_to_wav(noisy_LP, Nphase, signal_length)
    enhanced_wav = enhanced_wav / np.max(abs(enhanced_wav))

    return enhanced_wav
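cal_pcs operates offline on a mono waveform at 16 kHz; a minimal run on random audio (any float array works):

import numpy as np
from models.pcs400 import cal_pcs

wav = np.random.randn(16000).astype(np.float32)   # 1 second of noise
out = cal_pcs(wav)
print(out.shape, np.abs(out).max())               # same length, peak-normalized to 1.0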
models/stfts.py
ADDED
@@ -0,0 +1,73 @@
import torch
import torch.nn as nn

def mag_phase_stft(y, n_fft, hop_size, win_size, compress_factor=1.0, center=True, addeps=False):
    """
    Compute magnitude and phase using STFT.

    Args:
        y (torch.Tensor): Input audio signal.
        n_fft (int): FFT size.
        hop_size (int): Hop size.
        win_size (int): Window size.
        compress_factor (float, optional): Magnitude compression factor. Defaults to 1.0.
        center (bool, optional): Whether to center the signal before padding. Defaults to True.
        addeps (bool, optional): Whether to add an epsilon to magnitude and phase. Defaults to False.

    Returns:
        tuple: Magnitude, phase, and complex representation of the STFT.
    """
    # eps = torch.finfo(y.dtype).eps
    eps = 1e-10
    hann_window = torch.hann_window(win_size).to(y.device)
    stft_spec = torch.stft(
        y, n_fft,
        hop_length=hop_size,
        win_length=win_size,
        window=hann_window,
        center=center,
        pad_mode='reflect',
        normalized=False,
        return_complex=True)

    if not addeps:
        mag = torch.abs(stft_spec)
        pha = torch.angle(stft_spec)
    else:
        real_part = stft_spec.real
        imag_part = stft_spec.imag
        mag = torch.sqrt(real_part.pow(2) + imag_part.pow(2) + eps)
        pha = torch.atan2(imag_part + eps, real_part + eps)
    # Compress the magnitude
    mag = torch.pow(mag, compress_factor)
    com = torch.stack((mag * torch.cos(pha), mag * torch.sin(pha)), dim=-1)
    return mag, pha, com


def mag_phase_istft(mag, pha, n_fft, hop_size, win_size, compress_factor=1.0, center=True):
    """
    Inverse STFT to reconstruct the audio signal from magnitude and phase.

    Args:
        mag (torch.Tensor): Magnitude of the STFT.
        pha (torch.Tensor): Phase of the STFT.
        n_fft (int): FFT size.
        hop_size (int): Hop size.
        win_size (int): Window size.
        compress_factor (float, optional): Magnitude compression factor. Defaults to 1.0.
        center (bool, optional): Whether to center the signal before padding. Defaults to True.

    Returns:
        torch.Tensor: Reconstructed audio signal.
    """
    mag = torch.pow(mag, 1.0 / compress_factor)
    com = torch.complex(mag * torch.cos(pha), mag * torch.sin(pha))
    hann_window = torch.hann_window(win_size).to(com.device)
    wav = torch.istft(
        com,
        n_fft,
        hop_length=hop_size,
        win_length=win_size,
        window=hann_window,
        center=center)
    return wav
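Round-trip sanity check: with matching parameters and compress_factor, mag_phase_istft closely reconstructs the input of mag_phase_stft:

import torch
from models.stfts import mag_phase_stft, mag_phase_istft

y = torch.randn(1, 16000)
mag, pha, com = mag_phase_stft(y, n_fft=400, hop_size=100, win_size=400, compress_factor=0.3)
y_hat = mag_phase_istft(mag, pha, n_fft=400, hop_size=100, win_size=400, compress_factor=0.3)
print((y - y_hat).abs().max())   # reconstruction error should be near float precision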
recipes/SEMamba_advanced.yaml
ADDED
@@ -0,0 +1,66 @@
# Environment Settings
# These settings specify the hardware and distributed setup for the model training.
# Adjust `num_gpus` and `dist_cfg` according to your distributed training environment.
env_setting:
  num_gpus: 2                  # Number of GPUs. Now we don't support CPU mode.
  num_workers: 20              # Number of worker threads for data loading.
  seed: 1234                   # Seed for random number generators to ensure reproducibility.
  stdout_interval: 10
  checkpoint_interval: 1000    # save model to ckpt every N steps
  validation_interval: 1000
  summary_interval: 100
  dist_cfg:
    dist_backend: nccl               # Distributed training backend, 'nccl' for NVIDIA GPUs.
    dist_url: tcp://localhost:19477  # URL for initializing distributed training.
    world_size: 1                    # Total number of processes in the distributed training.

# Datapath Configuration
data_cfg:
  train_clean_json: data/train_clean.json
  train_noisy_json: data/train_noisy.json
  valid_clean_json: data/valid_clean.json
  valid_noisy_json: data/valid_noisy.json
  test_clean_json: data/test_clean.json
  test_noisy_json: data/test_noisy.json

# Training Configuration
# This section details parameters that directly influence the training process,
# including batch sizes, learning rates, and optimizer specifics.
training_cfg:
  training_epochs: 200         # Training epochs.
  batch_size: 4                # Training batch size.
  learning_rate: 0.0005        # Initial learning rate.
  adam_b1: 0.8                 # Beta1 hyperparameter for the AdamW optimizer.
  adam_b2: 0.99                # Beta2 hyperparameter for the AdamW optimizer.
  lr_decay: 0.99               # Learning rate decay per epoch.
  segment_size: 32000          # Audio segment size used during training, dependent on sampling rate.
  loss:
    metric: 0.05
    magnitude: 0.9
    phase: 0.3
    complex: 0.1
    time: 0.2
    consistancy: 0.1
  use_PCS400: False            # Use PCS or not

# STFT Configuration
# Configuration for Short-Time Fourier Transform (STFT), crucial for audio processing models.
stft_cfg:
  sampling_rate: 16000         # Audio sampling rate in Hz.
  n_fft: 400                   # FFT components for transforming audio signals.
  hop_size: 100                # Samples between successive frames.
  win_size: 400                # Window size used in FFT.

# Model Configuration
# Defines the architecture specifics of the model, including layer configurations and feature compression.
model_cfg:
  hid_feature: 64              # Channels in dense layers.
  compress_factor: 0.3         # Compression factor applied to extracted features.
  num_tfmamba: 4               # Number of Time-Frequency Mamba (TFMamba) blocks in the model.
  d_state: 16                  # Dimensionality of the state vector in Mamba blocks.
  d_conv: 4                    # Convolutional layer dimensionality within Mamba blocks.
  expand: 4                    # Expansion factor for the layers within the Mamba blocks.
  norm_epsilon: 0.00001        # Numerical stability in normalization layers within the Mamba blocks.
  beta: 2.0                    # Hyperparameter for the Learnable Sigmoid function.
  input_channel: 2             # Magnitude and Phase
  output_channel: 1            # Single Channel Speech Enhancement
requirements.txt
ADDED
@@ -0,0 +1,22 @@
packaging
librosa
soundfile
pyyaml
argparse
tensorboard
pesq
einops
matplotlib
torch==2.5.1
torchaudio==2.5.1
numpy==1.26.4
ultralytics
moviepy
supervision
opencv-python
ffmpeg-python
decord==0.6.0
pytorch_lightning==1.9.0
typeguard==2.13.3
torch_complex
rich
yolov8n-face.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d17b38523a994b13ee604b67f02791ca0f43b9f446a32fd7bc44e17c56ead077
size 6250099