Spaces:
Sleeping
Sleeping
roychao19477
commited on
Commit
·
bd9ffb1
1
Parent(s):
a66fd6a
Add application file
Browse files- README.md +6 -7
- app.py +376 -0
- mamba_ssm/.DS_Store +0 -0
- mamba_ssm/__init__.py +5 -0
- mamba_ssm/models/__init__.py +0 -0
- mamba_ssm/models/config_mamba.py +15 -0
- mamba_ssm/models/mixer_seq_simple.py +264 -0
- mamba_ssm/modules/__init__.py +0 -0
- mamba_ssm/modules/mamba_simple.py +353 -0
- mamba_ssm/ops/__init__.py +0 -0
- mamba_ssm/ops/selective_scan_interface.py +357 -0
- mamba_ssm/ops/triton/__init__.py +0 -0
- mamba_ssm/ops/triton/layernorm.py +635 -0
- mamba_ssm/ops/triton/selective_state_update.py +263 -0
- mamba_ssm/utils/__init__.py +0 -0
- mamba_ssm/utils/generation.py +387 -0
- mamba_ssm/utils/hf.py +23 -0
- models/codec_module.py +183 -0
- models/discriminator.py +56 -0
- models/generator.py +72 -0
- models/loss.py +145 -0
- models/lsigmoid.py +66 -0
- models/mamba_block.py +110 -0
- models/pcs400.py +53 -0
- models/stfts.py +73 -0
- recipes/SEMamba_advanced.yaml +66 -0
- requirements.txt +22 -0
- yolov8n-face.pt +3 -0
README.md
CHANGED
|
@@ -1,13 +1,12 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
-
|
| 4 |
-
colorFrom: green
|
| 5 |
colorTo: indigo
|
| 6 |
sdk: gradio
|
| 7 |
-
sdk_version: 5.
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
-
short_description:
|
|
|
|
|
|
|
| 11 |
---
|
| 12 |
-
|
| 13 |
-
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
|
| 1 |
---
|
| 2 |
+
title: Dev
|
| 3 |
+
colorFrom: purple
|
|
|
|
| 4 |
colorTo: indigo
|
| 5 |
sdk: gradio
|
| 6 |
+
sdk_version: 5.31.0
|
| 7 |
app_file: app.py
|
| 8 |
pinned: false
|
| 9 |
+
short_description: Dev
|
| 10 |
+
tags:
|
| 11 |
+
- Useless
|
| 12 |
---
|
|
|
|
|
|
app.py
ADDED
|
@@ -0,0 +1,376 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import shlex
|
| 2 |
+
import subprocess
|
| 3 |
+
import spaces
|
| 4 |
+
import torch
|
| 5 |
+
import os
|
| 6 |
+
import shutil
|
| 7 |
+
import glob
|
| 8 |
+
import gradio as gr
|
| 9 |
+
|
| 10 |
+
# install packages for mamba
|
| 11 |
+
def install_mamba():
|
| 12 |
+
subprocess.run(shlex.split("pip install https://github.com/state-spaces/mamba/releases/download/v2.2.2/mamba_ssm-2.2.2+cu122torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl"))
|
| 13 |
+
|
| 14 |
+
def clone_github():
|
| 15 |
+
subprocess.run([
|
| 16 |
+
"git", "clone",
|
| 17 |
+
f"https://RoyChao19477:{os.environ['GITHUB_TOKEN']}@github.com/RoyChao19477/for_HF_AVSEMamba.git",
|
| 18 |
+
])
|
| 19 |
+
# move all files except README.md
|
| 20 |
+
for item in glob.glob("for_HF_AVSEMamba/*"):
|
| 21 |
+
if os.path.basename(item) != "README.md":
|
| 22 |
+
if os.path.isdir(item):
|
| 23 |
+
shutil.move(item, ".")
|
| 24 |
+
else:
|
| 25 |
+
shutil.move(item, os.path.join(".", os.path.basename(item)))
|
| 26 |
+
|
| 27 |
+
#shutil.rmtree("tmp_repo")
|
| 28 |
+
#subprocess.run(["ls"], check=True)
|
| 29 |
+
|
| 30 |
+
install_mamba()
|
| 31 |
+
clone_github()
|
| 32 |
+
|
| 33 |
+
ABOUT = """
|
| 34 |
+
# SEMamba: Speech Enhancement
|
| 35 |
+
A Mamba-based model that denoises real-world audio.
|
| 36 |
+
Upload or record a noisy clip and click **Enhance** to hear + see its spectrogram.
|
| 37 |
+
"""
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
import torch
|
| 41 |
+
import ffmpeg
|
| 42 |
+
import torchaudio
|
| 43 |
+
import torchaudio.transforms as T
|
| 44 |
+
import yaml
|
| 45 |
+
import librosa
|
| 46 |
+
import librosa.display
|
| 47 |
+
import matplotlib
|
| 48 |
+
import numpy as np
|
| 49 |
+
import soundfile as sf
|
| 50 |
+
import matplotlib.pyplot as plt
|
| 51 |
+
from models.stfts import mag_phase_stft, mag_phase_istft
|
| 52 |
+
from models.generator import SEMamba
|
| 53 |
+
from models.pcs400 import cal_pcs
|
| 54 |
+
from ultralytics import YOLO
|
| 55 |
+
import supervision as sv
|
| 56 |
+
|
| 57 |
+
import gradio as gr
|
| 58 |
+
import cv2
|
| 59 |
+
import os
|
| 60 |
+
import tempfile
|
| 61 |
+
from ultralytics import YOLO
|
| 62 |
+
from moviepy import ImageSequenceClip
|
| 63 |
+
from moviepy.video import fx as vfx
|
| 64 |
+
from scipy.io import wavfile
|
| 65 |
+
from avse_code import run_avse
|
| 66 |
+
|
| 67 |
+
# Load face detector
|
| 68 |
+
model = YOLO("yolov8n-face.pt").cuda() # assumes CUDA available
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
from decord import VideoReader, cpu
|
| 72 |
+
from model import AVSEModule
|
| 73 |
+
from config import sampling_rate
|
| 74 |
+
import spaces
|
| 75 |
+
|
| 76 |
+
# Load model once globally
|
| 77 |
+
#ckpt_path = "ckpts/ep215_0906.oat.ckpt"
|
| 78 |
+
#model = AVSEModule.load_from_checkpoint(ckpt_path)
|
| 79 |
+
avse_model = AVSEModule()
|
| 80 |
+
#avse_state_dict = torch.load("ckpts/ep215_0906.oat.ckpt")
|
| 81 |
+
avse_state_dict = torch.load("ckpts/ep220_0908.oat.ckpt")
|
| 82 |
+
avse_model.load_state_dict(avse_state_dict, strict=True)
|
| 83 |
+
avse_model.to("cuda")
|
| 84 |
+
avse_model.eval()
|
| 85 |
+
|
| 86 |
+
@spaces.GPU
|
| 87 |
+
def run_avse_inference(video_path, audio_path):
|
| 88 |
+
estimated = run_avse(video_path, audio_path)
|
| 89 |
+
# Load audio
|
| 90 |
+
#noisy, _ = sf.read(audio_path, dtype='float32') # (N, )
|
| 91 |
+
#noisy = torch.tensor(noisy).unsqueeze(0) # (1, N)
|
| 92 |
+
noisy = wavfile.read(audio_path)[1].astype(np.float32) / (2 ** 15)
|
| 93 |
+
|
| 94 |
+
# Norm.
|
| 95 |
+
#noisy = noisy * (0.8 / np.max(np.abs(noisy)))
|
| 96 |
+
|
| 97 |
+
# Load grayscale video
|
| 98 |
+
vr = VideoReader(video_path, ctx=cpu(0))
|
| 99 |
+
frames = vr.get_batch(list(range(len(vr)))).asnumpy()
|
| 100 |
+
bg_frames = np.array([
|
| 101 |
+
cv2.cvtColor(frames[i], cv2.COLOR_RGB2GRAY) for i in range(len(frames))
|
| 102 |
+
]).astype(np.float32)
|
| 103 |
+
bg_frames /= 255.0
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
# Combine into input dict (match what model.enhance expects)
|
| 107 |
+
data = {
|
| 108 |
+
"noisy_audio": noisy,
|
| 109 |
+
"video_frames": bg_frames[np.newaxis, ...]
|
| 110 |
+
}
|
| 111 |
+
|
| 112 |
+
with torch.no_grad():
|
| 113 |
+
estimated = avse_model.enhance(data).reshape(-1)
|
| 114 |
+
|
| 115 |
+
# Save result
|
| 116 |
+
tmp_wav = audio_path.replace(".wav", "_enhanced.wav")
|
| 117 |
+
sf.write(tmp_wav, estimated, samplerate=sampling_rate)
|
| 118 |
+
|
| 119 |
+
return tmp_wav
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def extract_resampled_audio(video_path, target_sr=16000):
|
| 123 |
+
# Step 1: extract audio via torchaudio
|
| 124 |
+
# (moviepy will still extract it to wav temp file)
|
| 125 |
+
tmp_audio_path = tempfile.mktemp(suffix=".wav")
|
| 126 |
+
subprocess.run(["ffmpeg", "-y", "-i", video_path, "-vn", "-acodec", "pcm_s16le", "-ar", "44100", tmp_audio_path])
|
| 127 |
+
|
| 128 |
+
# Step 2: Load and resample
|
| 129 |
+
waveform, sr = torchaudio.load(tmp_audio_path)
|
| 130 |
+
if sr != target_sr:
|
| 131 |
+
resampler = T.Resample(orig_freq=sr, new_freq=target_sr)
|
| 132 |
+
waveform = resampler(waveform)
|
| 133 |
+
|
| 134 |
+
# Step 3: Save resampled audio
|
| 135 |
+
resampled_audio_path = tempfile.mktemp(suffix="_16k.wav")
|
| 136 |
+
torchaudio.save(resampled_audio_path, waveform, sample_rate=target_sr)
|
| 137 |
+
return resampled_audio_path
|
| 138 |
+
|
| 139 |
+
@spaces.GPU
|
| 140 |
+
def extract_faces(video_file):
|
| 141 |
+
# Step 0: Check resolution
|
| 142 |
+
cap = cv2.VideoCapture(video_file)
|
| 143 |
+
width = cap.get(cv2.CAP_PROP_FRAME_WIDTH)
|
| 144 |
+
height = cap.get(cv2.CAP_PROP_FRAME_HEIGHT)
|
| 145 |
+
cap.release()
|
| 146 |
+
|
| 147 |
+
# Step 1: Downsample if needed
|
| 148 |
+
if width > 1280 or height > 720:
|
| 149 |
+
resized_path = tempfile.mktemp(suffix=".mp4")
|
| 150 |
+
subprocess.run([
|
| 151 |
+
"ffmpeg", "-y", "-i", video_file,
|
| 152 |
+
"-vf", "scale='min(1280,iw)':-2",
|
| 153 |
+
"-c:v", "libx264", "-crf", "28",
|
| 154 |
+
"-preset", "fast", "-an", resized_path
|
| 155 |
+
])
|
| 156 |
+
video_file = resized_path
|
| 157 |
+
|
| 158 |
+
cap = cv2.VideoCapture(video_file)
|
| 159 |
+
fps = cap.get(cv2.CAP_PROP_FPS)
|
| 160 |
+
frames = []
|
| 161 |
+
|
| 162 |
+
while True:
|
| 163 |
+
ret, frame = cap.read()
|
| 164 |
+
if not ret:
|
| 165 |
+
break
|
| 166 |
+
|
| 167 |
+
# Inference
|
| 168 |
+
results = model(frame, verbose=False)[0]
|
| 169 |
+
for box in results.boxes:
|
| 170 |
+
# version 1
|
| 171 |
+
# x1, y1, x2, y2 = map(int, box.xyxy[0])
|
| 172 |
+
|
| 173 |
+
# version 2
|
| 174 |
+
h, w, _ = frame.shape
|
| 175 |
+
x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
|
| 176 |
+
pad_ratio = 0.5 # 30% padding
|
| 177 |
+
|
| 178 |
+
dx = (x2 - x1) * pad_ratio
|
| 179 |
+
dy = (y2 - y1) * pad_ratio
|
| 180 |
+
|
| 181 |
+
x1 = int(max(0, x1 - dx))
|
| 182 |
+
y1 = int(max(0, y1 - dy))
|
| 183 |
+
x2 = int(min(w, x2 + dx))
|
| 184 |
+
y2 = int(min(h, y2 + dy))
|
| 185 |
+
# Added for v3
|
| 186 |
+
shift_down = int(0.1 * (y2 - y1))
|
| 187 |
+
y1 = int(min(max(0, y1 + shift_down), h))
|
| 188 |
+
y2 = int(min(max(0, y2 + shift_down), h))
|
| 189 |
+
face_crop = frame[y1:y2, x1:x2]
|
| 190 |
+
if face_crop.size != 0:
|
| 191 |
+
resized = cv2.resize(face_crop, (224, 224))
|
| 192 |
+
frames.append(resized)
|
| 193 |
+
|
| 194 |
+
#h_crop, w_crop = face_crop.shape[:2]
|
| 195 |
+
#side = min(h_crop, w_crop)
|
| 196 |
+
#start_y = (h_crop - side) // 2
|
| 197 |
+
#start_x = (w_crop - side) // 2
|
| 198 |
+
#square_crop = face_crop[start_y:start_y+side, start_x:start_x+side]
|
| 199 |
+
#resized = cv2.resize(square_crop, (224, 224))
|
| 200 |
+
#frames.append(resized)
|
| 201 |
+
|
| 202 |
+
break # only one face per frame
|
| 203 |
+
|
| 204 |
+
cap.release()
|
| 205 |
+
|
| 206 |
+
# Save as video
|
| 207 |
+
tmpdir = tempfile.mkdtemp()
|
| 208 |
+
output_path = os.path.join(tmpdir, "face_only_video.mp4")
|
| 209 |
+
#clip = ImageSequenceClip([cv2.cvtColor(f, cv2.COLOR_BGR2RGB) for f in frames], fps=25)
|
| 210 |
+
#clip = ImageSequenceClip([cv2.cvtColor(f, cv2.COLOR_BGR2RGB) for f in frames], fps=fps)
|
| 211 |
+
clip = ImageSequenceClip(
|
| 212 |
+
[cv2.cvtColor(cv2.resize(f, (224, 224)), cv2.COLOR_BGR2RGB) for f in frames],
|
| 213 |
+
fps=fps
|
| 214 |
+
).fx(vfx.flip_vertical)
|
| 215 |
+
clip.write_videofile(output_path, codec="libx264", audio=False, fps=25)
|
| 216 |
+
|
| 217 |
+
# Save audio from original, resampled to 16kHz
|
| 218 |
+
audio_path = os.path.join(tmpdir, "audio_16k.wav")
|
| 219 |
+
|
| 220 |
+
# Extract audio using ffmpeg-python (more robust than moviepy)
|
| 221 |
+
ffmpeg.input(video_file).output(
|
| 222 |
+
audio_path,
|
| 223 |
+
ar=16000, # resample to 16k
|
| 224 |
+
ac=1, # mono
|
| 225 |
+
format='wav',
|
| 226 |
+
vn=None # no video
|
| 227 |
+
).run(overwrite_output=True)
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
# ------------------------------- #
|
| 233 |
+
# AVSE models
|
| 234 |
+
|
| 235 |
+
enhanced_audio_path = run_avse_inference(output_path, audio_path)
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
return output_path, enhanced_audio_path
|
| 239 |
+
#return output_path, audio_path
|
| 240 |
+
|
| 241 |
+
iface = gr.Interface(
|
| 242 |
+
fn=extract_faces,
|
| 243 |
+
inputs=gr.Video(label="Upload or record your video"),
|
| 244 |
+
outputs=[
|
| 245 |
+
gr.Video(label="Detected Face Only Video"),
|
| 246 |
+
#gr.Audio(label="Extracted Audio (16kHz)", type="filepath"),
|
| 247 |
+
gr.Audio(label="Enhanced Audio", type="filepath")
|
| 248 |
+
],
|
| 249 |
+
title="Face Detector",
|
| 250 |
+
description="Upload or record a video. We'll crop face regions and return a face-only video and its 16kHz audio."
|
| 251 |
+
)
|
| 252 |
+
|
| 253 |
+
iface.launch()
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
ckpt = "ckpts/SEMamba_advanced.pth"
|
| 258 |
+
cfg_f = "recipes/SEMamba_advanced.yaml"
|
| 259 |
+
|
| 260 |
+
# load config
|
| 261 |
+
with open(cfg_f, 'r') as f:
|
| 262 |
+
cfg = yaml.safe_load(f)
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 266 |
+
device = "cuda"
|
| 267 |
+
model = SEMamba(cfg).to(device)
|
| 268 |
+
#sdict = torch.load(ckpt, map_location=device)
|
| 269 |
+
#model.load_state_dict(sdict["generator"])
|
| 270 |
+
#model.eval()
|
| 271 |
+
|
| 272 |
+
@spaces.GPU
|
| 273 |
+
def enhance(filepath, model_name):
|
| 274 |
+
# Load model based on selection
|
| 275 |
+
ckpt_path = {
|
| 276 |
+
"VCTK-Demand": "ckpts/SEMamba_advanced.pth",
|
| 277 |
+
"VCTK+DNS": "ckpts/vd.pth"
|
| 278 |
+
}[model_name]
|
| 279 |
+
|
| 280 |
+
print("Loading:", ckpt_path)
|
| 281 |
+
model.load_state_dict(torch.load(ckpt_path, map_location=device)["generator"])
|
| 282 |
+
model.eval()
|
| 283 |
+
with torch.no_grad():
|
| 284 |
+
# load & resample
|
| 285 |
+
wav, orig_sr = librosa.load(filepath, sr=None)
|
| 286 |
+
noisy_wav = wav.copy()
|
| 287 |
+
if orig_sr != 16000:
|
| 288 |
+
wav = librosa.resample(wav, orig_sr=orig_sr, target_sr=16000)
|
| 289 |
+
x = torch.from_numpy(wav).float().to(device)
|
| 290 |
+
norm = torch.sqrt(len(x)/torch.sum(x**2))
|
| 291 |
+
#x = (x * norm).unsqueeze(0)
|
| 292 |
+
x = (x * norm)
|
| 293 |
+
|
| 294 |
+
# split into 4s segments (64000 samples)
|
| 295 |
+
segment_len = 4 * 16000
|
| 296 |
+
chunks = x.split(segment_len)
|
| 297 |
+
enhanced_chunks = []
|
| 298 |
+
|
| 299 |
+
for chunk in chunks:
|
| 300 |
+
if len(chunk) < segment_len:
|
| 301 |
+
#pad = torch.zeros(segment_len - len(chunk), device=chunk.device)
|
| 302 |
+
pad = (torch.randn(segment_len - len(chunk), device=chunk.device) * 1e-4)
|
| 303 |
+
chunk = torch.cat([chunk, pad])
|
| 304 |
+
chunk = chunk.unsqueeze(0)
|
| 305 |
+
|
| 306 |
+
amp, pha, _ = mag_phase_stft(chunk, 400, 100, 400, 0.3)
|
| 307 |
+
amp2, pha2, _ = model(amp, pha)
|
| 308 |
+
out = mag_phase_istft(amp2, pha2, 400, 100, 400, 0.3)
|
| 309 |
+
out = (out / norm).squeeze(0)
|
| 310 |
+
enhanced_chunks.append(out)
|
| 311 |
+
|
| 312 |
+
out = torch.cat(enhanced_chunks)[:len(x)].cpu().numpy() # trim padding
|
| 313 |
+
|
| 314 |
+
# back to original rate
|
| 315 |
+
if orig_sr != 16000:
|
| 316 |
+
out = librosa.resample(out, orig_sr=16000, target_sr=orig_sr)
|
| 317 |
+
|
| 318 |
+
# Normalize
|
| 319 |
+
peak = np.max(np.abs(out))
|
| 320 |
+
if peak > 0.05:
|
| 321 |
+
out = out / peak * 0.85
|
| 322 |
+
|
| 323 |
+
# write file
|
| 324 |
+
sf.write("enhanced.wav", out, orig_sr)
|
| 325 |
+
|
| 326 |
+
# spectrograms
|
| 327 |
+
fig, axs = plt.subplots(1, 2, figsize=(16, 4))
|
| 328 |
+
|
| 329 |
+
# noisy
|
| 330 |
+
D_noisy = librosa.stft(noisy_wav, n_fft=512, hop_length=256)
|
| 331 |
+
S_noisy = librosa.amplitude_to_db(np.abs(D_noisy), ref=np.max)
|
| 332 |
+
librosa.display.specshow(S_noisy, sr=orig_sr, hop_length=256, x_axis="time", y_axis="hz", ax=axs[0], vmax=0)
|
| 333 |
+
axs[0].set_title("Noisy Spectrogram")
|
| 334 |
+
|
| 335 |
+
# enhanced
|
| 336 |
+
D_clean = librosa.stft(out, n_fft=512, hop_length=256)
|
| 337 |
+
S_clean = librosa.amplitude_to_db(np.abs(D_clean), ref=np.max)
|
| 338 |
+
librosa.display.specshow(S_clean, sr=orig_sr, hop_length=256, x_axis="time", y_axis="hz", ax=axs[1], vmax=0)
|
| 339 |
+
#librosa.display.specshow(S_clean, sr=16000, hop_length=512, x_axis="time", y_axis="hz", ax=axs[1], vmax=0)
|
| 340 |
+
axs[1].set_title("Enhanced Spectrogram")
|
| 341 |
+
|
| 342 |
+
plt.tight_layout()
|
| 343 |
+
|
| 344 |
+
return "enhanced.wav", fig
|
| 345 |
+
|
| 346 |
+
#with gr.Blocks() as demo:
|
| 347 |
+
# gr.Markdown(ABOUT)
|
| 348 |
+
# input_audio = gr.Audio(label="Input Audio", type="filepath", interactive=True)
|
| 349 |
+
# enhance_btn = gr.Button("Enhance")
|
| 350 |
+
# output_audio = gr.Audio(label="Enhanced Audio", type="filepath")
|
| 351 |
+
# plot_output = gr.Plot(label="Spectrograms")
|
| 352 |
+
#
|
| 353 |
+
# enhance_btn.click(fn=enhance, inputs=input_audio, outputs=[output_audio, plot_output])
|
| 354 |
+
#
|
| 355 |
+
#demo.queue().launch()
|
| 356 |
+
|
| 357 |
+
with gr.Blocks() as demo:
|
| 358 |
+
gr.Markdown(ABOUT)
|
| 359 |
+
input_audio = gr.Audio(label="Input Audio", type="filepath", interactive=True)
|
| 360 |
+
model_choice = gr.Radio(
|
| 361 |
+
label="Choose Model (The use of VCTK+DNS is recommended)",
|
| 362 |
+
choices=["VCTK-Demand", "VCTK+DNS"],
|
| 363 |
+
value="VCTK-Demand"
|
| 364 |
+
)
|
| 365 |
+
enhance_btn = gr.Button("Enhance")
|
| 366 |
+
output_audio = gr.Audio(label="Enhanced Audio", type="filepath")
|
| 367 |
+
plot_output = gr.Plot(label="Spectrograms")
|
| 368 |
+
|
| 369 |
+
enhance_btn.click(
|
| 370 |
+
fn=enhance,
|
| 371 |
+
inputs=[input_audio, model_choice],
|
| 372 |
+
outputs=[output_audio, plot_output]
|
| 373 |
+
)
|
| 374 |
+
gr.Markdown("**Note**: The current models are trained on 16kHz audio. Therefore, any input audio not sampled at 16kHz will be automatically resampled before enhancement.")
|
| 375 |
+
|
| 376 |
+
demo.queue().launch()
|
mamba_ssm/.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
mamba_ssm/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__version__ = "1.2.2"
|
| 2 |
+
|
| 3 |
+
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
|
| 4 |
+
from mamba_ssm.modules.mamba_simple import Mamba
|
| 5 |
+
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
|
mamba_ssm/models/__init__.py
ADDED
|
File without changes
|
mamba_ssm/models/config_mamba.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass, field
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
@dataclass
|
| 5 |
+
class MambaConfig:
|
| 6 |
+
|
| 7 |
+
d_model: int = 2560
|
| 8 |
+
n_layer: int = 64
|
| 9 |
+
vocab_size: int = 50277
|
| 10 |
+
ssm_cfg: dict = field(default_factory=dict)
|
| 11 |
+
rms_norm: bool = True
|
| 12 |
+
residual_in_fp32: bool = True
|
| 13 |
+
fused_add_norm: bool = True
|
| 14 |
+
pad_vocab_size_multiple: int = 8
|
| 15 |
+
tie_embeddings: bool = True
|
mamba_ssm/models/mixer_seq_simple.py
ADDED
|
@@ -0,0 +1,264 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2023, Albert Gu, Tri Dao.
|
| 2 |
+
|
| 3 |
+
import math
|
| 4 |
+
from functools import partial
|
| 5 |
+
import json
|
| 6 |
+
import os
|
| 7 |
+
|
| 8 |
+
from collections import namedtuple
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
|
| 13 |
+
from mamba_ssm.models.config_mamba import MambaConfig
|
| 14 |
+
from mamba_ssm.modules.mamba_simple import Mamba, Block
|
| 15 |
+
from mamba_ssm.utils.generation import GenerationMixin
|
| 16 |
+
from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf
|
| 17 |
+
|
| 18 |
+
try:
|
| 19 |
+
from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn
|
| 20 |
+
except ImportError:
|
| 21 |
+
RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def create_block(
|
| 25 |
+
d_model,
|
| 26 |
+
ssm_cfg=None,
|
| 27 |
+
norm_epsilon=1e-5,
|
| 28 |
+
rms_norm=False,
|
| 29 |
+
residual_in_fp32=False,
|
| 30 |
+
fused_add_norm=False,
|
| 31 |
+
layer_idx=None,
|
| 32 |
+
device=None,
|
| 33 |
+
dtype=None,
|
| 34 |
+
):
|
| 35 |
+
if ssm_cfg is None:
|
| 36 |
+
ssm_cfg = {}
|
| 37 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
| 38 |
+
mixer_cls = partial(Mamba, layer_idx=layer_idx, **ssm_cfg, **factory_kwargs)
|
| 39 |
+
norm_cls = partial(
|
| 40 |
+
nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs
|
| 41 |
+
)
|
| 42 |
+
block = Block(
|
| 43 |
+
d_model,
|
| 44 |
+
mixer_cls,
|
| 45 |
+
norm_cls=norm_cls,
|
| 46 |
+
fused_add_norm=fused_add_norm,
|
| 47 |
+
residual_in_fp32=residual_in_fp32,
|
| 48 |
+
)
|
| 49 |
+
block.layer_idx = layer_idx
|
| 50 |
+
return block
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
# https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
|
| 54 |
+
def _init_weights(
|
| 55 |
+
module,
|
| 56 |
+
n_layer,
|
| 57 |
+
initializer_range=0.02, # Now only used for embedding layer.
|
| 58 |
+
rescale_prenorm_residual=True,
|
| 59 |
+
n_residuals_per_layer=1, # Change to 2 if we have MLP
|
| 60 |
+
):
|
| 61 |
+
if isinstance(module, nn.Linear):
|
| 62 |
+
if module.bias is not None:
|
| 63 |
+
if not getattr(module.bias, "_no_reinit", False):
|
| 64 |
+
nn.init.zeros_(module.bias)
|
| 65 |
+
elif isinstance(module, nn.Embedding):
|
| 66 |
+
nn.init.normal_(module.weight, std=initializer_range)
|
| 67 |
+
|
| 68 |
+
if rescale_prenorm_residual:
|
| 69 |
+
# Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
|
| 70 |
+
# > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
|
| 71 |
+
# > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
|
| 72 |
+
# > -- GPT-2 :: https://openai.com/blog/better-language-models/
|
| 73 |
+
#
|
| 74 |
+
# Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
|
| 75 |
+
for name, p in module.named_parameters():
|
| 76 |
+
if name in ["out_proj.weight", "fc2.weight"]:
|
| 77 |
+
# Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
|
| 78 |
+
# Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
|
| 79 |
+
# We need to reinit p since this code could be called multiple times
|
| 80 |
+
# Having just p *= scale would repeatedly scale it down
|
| 81 |
+
nn.init.kaiming_uniform_(p, a=math.sqrt(5))
|
| 82 |
+
with torch.no_grad():
|
| 83 |
+
p /= math.sqrt(n_residuals_per_layer * n_layer)
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
class MixerModel(nn.Module):
|
| 87 |
+
def __init__(
|
| 88 |
+
self,
|
| 89 |
+
d_model: int,
|
| 90 |
+
n_layer: int,
|
| 91 |
+
vocab_size: int,
|
| 92 |
+
ssm_cfg=None,
|
| 93 |
+
norm_epsilon: float = 1e-5,
|
| 94 |
+
rms_norm: bool = False,
|
| 95 |
+
initializer_cfg=None,
|
| 96 |
+
fused_add_norm=False,
|
| 97 |
+
residual_in_fp32=False,
|
| 98 |
+
device=None,
|
| 99 |
+
dtype=None,
|
| 100 |
+
) -> None:
|
| 101 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
| 102 |
+
super().__init__()
|
| 103 |
+
self.residual_in_fp32 = residual_in_fp32
|
| 104 |
+
|
| 105 |
+
self.embedding = nn.Embedding(vocab_size, d_model, **factory_kwargs)
|
| 106 |
+
|
| 107 |
+
# We change the order of residual and layer norm:
|
| 108 |
+
# Instead of LN -> Attn / MLP -> Add, we do:
|
| 109 |
+
# Add -> LN -> Attn / MLP / Mixer, returning both the residual branch (output of Add) and
|
| 110 |
+
# the main branch (output of MLP / Mixer). The model definition is unchanged.
|
| 111 |
+
# This is for performance reason: we can fuse add + layer_norm.
|
| 112 |
+
self.fused_add_norm = fused_add_norm
|
| 113 |
+
if self.fused_add_norm:
|
| 114 |
+
if layer_norm_fn is None or rms_norm_fn is None:
|
| 115 |
+
raise ImportError("Failed to import Triton LayerNorm / RMSNorm kernels")
|
| 116 |
+
|
| 117 |
+
self.layers = nn.ModuleList(
|
| 118 |
+
[
|
| 119 |
+
create_block(
|
| 120 |
+
d_model,
|
| 121 |
+
ssm_cfg=ssm_cfg,
|
| 122 |
+
norm_epsilon=norm_epsilon,
|
| 123 |
+
rms_norm=rms_norm,
|
| 124 |
+
residual_in_fp32=residual_in_fp32,
|
| 125 |
+
fused_add_norm=fused_add_norm,
|
| 126 |
+
layer_idx=i,
|
| 127 |
+
**factory_kwargs,
|
| 128 |
+
)
|
| 129 |
+
for i in range(n_layer)
|
| 130 |
+
]
|
| 131 |
+
)
|
| 132 |
+
|
| 133 |
+
self.norm_f = (nn.LayerNorm if not rms_norm else RMSNorm)(
|
| 134 |
+
d_model, eps=norm_epsilon, **factory_kwargs
|
| 135 |
+
)
|
| 136 |
+
|
| 137 |
+
self.apply(
|
| 138 |
+
partial(
|
| 139 |
+
_init_weights,
|
| 140 |
+
n_layer=n_layer,
|
| 141 |
+
**(initializer_cfg if initializer_cfg is not None else {}),
|
| 142 |
+
)
|
| 143 |
+
)
|
| 144 |
+
|
| 145 |
+
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
|
| 146 |
+
return {
|
| 147 |
+
i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
|
| 148 |
+
for i, layer in enumerate(self.layers)
|
| 149 |
+
}
|
| 150 |
+
|
| 151 |
+
def forward(self, input_ids, inference_params=None):
|
| 152 |
+
hidden_states = self.embedding(input_ids)
|
| 153 |
+
residual = None
|
| 154 |
+
for layer in self.layers:
|
| 155 |
+
hidden_states, residual = layer(
|
| 156 |
+
hidden_states, residual, inference_params=inference_params
|
| 157 |
+
)
|
| 158 |
+
if not self.fused_add_norm:
|
| 159 |
+
residual = (hidden_states + residual) if residual is not None else hidden_states
|
| 160 |
+
hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
|
| 161 |
+
else:
|
| 162 |
+
# Set prenorm=False here since we don't need the residual
|
| 163 |
+
fused_add_norm_fn = rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn
|
| 164 |
+
hidden_states = fused_add_norm_fn(
|
| 165 |
+
hidden_states,
|
| 166 |
+
self.norm_f.weight,
|
| 167 |
+
self.norm_f.bias,
|
| 168 |
+
eps=self.norm_f.eps,
|
| 169 |
+
residual=residual,
|
| 170 |
+
prenorm=False,
|
| 171 |
+
residual_in_fp32=self.residual_in_fp32,
|
| 172 |
+
)
|
| 173 |
+
return hidden_states
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
class MambaLMHeadModel(nn.Module, GenerationMixin):
|
| 177 |
+
|
| 178 |
+
def __init__(
|
| 179 |
+
self,
|
| 180 |
+
config: MambaConfig,
|
| 181 |
+
initializer_cfg=None,
|
| 182 |
+
device=None,
|
| 183 |
+
dtype=None,
|
| 184 |
+
) -> None:
|
| 185 |
+
self.config = config
|
| 186 |
+
d_model = config.d_model
|
| 187 |
+
n_layer = config.n_layer
|
| 188 |
+
vocab_size = config.vocab_size
|
| 189 |
+
ssm_cfg = config.ssm_cfg
|
| 190 |
+
rms_norm = config.rms_norm
|
| 191 |
+
residual_in_fp32 = config.residual_in_fp32
|
| 192 |
+
fused_add_norm = config.fused_add_norm
|
| 193 |
+
pad_vocab_size_multiple = config.pad_vocab_size_multiple
|
| 194 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
| 195 |
+
|
| 196 |
+
super().__init__()
|
| 197 |
+
if vocab_size % pad_vocab_size_multiple != 0:
|
| 198 |
+
vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple)
|
| 199 |
+
self.backbone = MixerModel(
|
| 200 |
+
d_model=d_model,
|
| 201 |
+
n_layer=n_layer,
|
| 202 |
+
vocab_size=vocab_size,
|
| 203 |
+
ssm_cfg=ssm_cfg,
|
| 204 |
+
rms_norm=rms_norm,
|
| 205 |
+
initializer_cfg=initializer_cfg,
|
| 206 |
+
fused_add_norm=fused_add_norm,
|
| 207 |
+
residual_in_fp32=residual_in_fp32,
|
| 208 |
+
**factory_kwargs,
|
| 209 |
+
)
|
| 210 |
+
self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs)
|
| 211 |
+
|
| 212 |
+
# Initialize weights and apply final processing
|
| 213 |
+
self.apply(
|
| 214 |
+
partial(
|
| 215 |
+
_init_weights,
|
| 216 |
+
n_layer=n_layer,
|
| 217 |
+
**(initializer_cfg if initializer_cfg is not None else {}),
|
| 218 |
+
)
|
| 219 |
+
)
|
| 220 |
+
self.tie_weights()
|
| 221 |
+
|
| 222 |
+
def tie_weights(self):
|
| 223 |
+
if self.config.tie_embeddings:
|
| 224 |
+
self.lm_head.weight = self.backbone.embedding.weight
|
| 225 |
+
|
| 226 |
+
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
|
| 227 |
+
return self.backbone.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
|
| 228 |
+
|
| 229 |
+
def forward(self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0):
|
| 230 |
+
"""
|
| 231 |
+
"position_ids" is just to be compatible with Transformer generation. We don't use it.
|
| 232 |
+
num_last_tokens: if > 0, only return the logits for the last n tokens
|
| 233 |
+
"""
|
| 234 |
+
hidden_states = self.backbone(input_ids, inference_params=inference_params)
|
| 235 |
+
if num_last_tokens > 0:
|
| 236 |
+
hidden_states = hidden_states[:, -num_last_tokens:]
|
| 237 |
+
lm_logits = self.lm_head(hidden_states)
|
| 238 |
+
CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
|
| 239 |
+
return CausalLMOutput(logits=lm_logits)
|
| 240 |
+
|
| 241 |
+
@classmethod
|
| 242 |
+
def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs):
|
| 243 |
+
config_data = load_config_hf(pretrained_model_name)
|
| 244 |
+
config = MambaConfig(**config_data)
|
| 245 |
+
model = cls(config, device=device, dtype=dtype, **kwargs)
|
| 246 |
+
model.load_state_dict(load_state_dict_hf(pretrained_model_name, device=device, dtype=dtype))
|
| 247 |
+
return model
|
| 248 |
+
|
| 249 |
+
def save_pretrained(self, save_directory):
|
| 250 |
+
"""
|
| 251 |
+
Minimal implementation of save_pretrained for MambaLMHeadModel.
|
| 252 |
+
Save the model and its configuration file to a directory.
|
| 253 |
+
"""
|
| 254 |
+
# Ensure save_directory exists
|
| 255 |
+
os.makedirs(save_directory, exist_ok=True)
|
| 256 |
+
|
| 257 |
+
# Save the model's state_dict
|
| 258 |
+
model_path = os.path.join(save_directory, 'pytorch_model.bin')
|
| 259 |
+
torch.save(self.state_dict(), model_path)
|
| 260 |
+
|
| 261 |
+
# Save the configuration of the model
|
| 262 |
+
config_path = os.path.join(save_directory, 'config.json')
|
| 263 |
+
with open(config_path, 'w') as f:
|
| 264 |
+
json.dump(self.config.__dict__, f)
|
mamba_ssm/modules/__init__.py
ADDED
|
File without changes
|
mamba_ssm/modules/mamba_simple.py
ADDED
|
@@ -0,0 +1,353 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2023, Tri Dao, Albert Gu.
|
| 2 |
+
|
| 3 |
+
import math
|
| 4 |
+
from typing import Optional
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
from torch import Tensor
|
| 10 |
+
|
| 11 |
+
from einops import rearrange, repeat
|
| 12 |
+
|
| 13 |
+
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
|
| 14 |
+
|
| 15 |
+
try:
|
| 16 |
+
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
|
| 17 |
+
except ImportError:
|
| 18 |
+
causal_conv1d_fn, causal_conv1d_update = None, None
|
| 19 |
+
|
| 20 |
+
try:
|
| 21 |
+
from mamba_ssm.ops.triton.selective_state_update import selective_state_update
|
| 22 |
+
except ImportError:
|
| 23 |
+
selective_state_update = None
|
| 24 |
+
|
| 25 |
+
try:
|
| 26 |
+
from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn
|
| 27 |
+
except ImportError:
|
| 28 |
+
RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class Mamba(nn.Module):
|
| 32 |
+
def __init__(
|
| 33 |
+
self,
|
| 34 |
+
d_model,
|
| 35 |
+
d_state=16,
|
| 36 |
+
d_conv=4,
|
| 37 |
+
expand=2,
|
| 38 |
+
dt_rank="auto",
|
| 39 |
+
dt_min=0.001,
|
| 40 |
+
dt_max=0.1,
|
| 41 |
+
dt_init="random",
|
| 42 |
+
dt_scale=1.0,
|
| 43 |
+
dt_init_floor=1e-4,
|
| 44 |
+
conv_bias=True,
|
| 45 |
+
bias=False,
|
| 46 |
+
use_fast_path=True, # Fused kernel options
|
| 47 |
+
layer_idx=None,
|
| 48 |
+
device=None,
|
| 49 |
+
dtype=None,
|
| 50 |
+
):
|
| 51 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
| 52 |
+
super().__init__()
|
| 53 |
+
self.d_model = d_model
|
| 54 |
+
self.d_state = d_state
|
| 55 |
+
self.d_conv = d_conv
|
| 56 |
+
self.expand = expand
|
| 57 |
+
self.d_inner = int(self.expand * self.d_model)
|
| 58 |
+
self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
|
| 59 |
+
self.use_fast_path = use_fast_path
|
| 60 |
+
self.layer_idx = layer_idx
|
| 61 |
+
|
| 62 |
+
self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs)
|
| 63 |
+
|
| 64 |
+
self.conv1d = nn.Conv1d(
|
| 65 |
+
in_channels=self.d_inner,
|
| 66 |
+
out_channels=self.d_inner,
|
| 67 |
+
bias=conv_bias,
|
| 68 |
+
kernel_size=d_conv,
|
| 69 |
+
groups=self.d_inner,
|
| 70 |
+
padding=d_conv - 1,
|
| 71 |
+
**factory_kwargs,
|
| 72 |
+
)
|
| 73 |
+
|
| 74 |
+
self.activation = "silu"
|
| 75 |
+
self.act = nn.SiLU()
|
| 76 |
+
|
| 77 |
+
self.x_proj = nn.Linear(
|
| 78 |
+
self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs
|
| 79 |
+
)
|
| 80 |
+
self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs)
|
| 81 |
+
|
| 82 |
+
# Initialize special dt projection to preserve variance at initialization
|
| 83 |
+
dt_init_std = self.dt_rank**-0.5 * dt_scale
|
| 84 |
+
if dt_init == "constant":
|
| 85 |
+
nn.init.constant_(self.dt_proj.weight, dt_init_std)
|
| 86 |
+
elif dt_init == "random":
|
| 87 |
+
nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
|
| 88 |
+
else:
|
| 89 |
+
raise NotImplementedError
|
| 90 |
+
|
| 91 |
+
# Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
|
| 92 |
+
dt = torch.exp(
|
| 93 |
+
torch.rand(self.d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
|
| 94 |
+
+ math.log(dt_min)
|
| 95 |
+
).clamp(min=dt_init_floor)
|
| 96 |
+
# Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
|
| 97 |
+
inv_dt = dt + torch.log(-torch.expm1(-dt))
|
| 98 |
+
with torch.no_grad():
|
| 99 |
+
self.dt_proj.bias.copy_(inv_dt)
|
| 100 |
+
# Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
|
| 101 |
+
self.dt_proj.bias._no_reinit = True
|
| 102 |
+
|
| 103 |
+
# S4D real initialization
|
| 104 |
+
A = repeat(
|
| 105 |
+
torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),
|
| 106 |
+
"n -> d n",
|
| 107 |
+
d=self.d_inner,
|
| 108 |
+
).contiguous()
|
| 109 |
+
A_log = torch.log(A) # Keep A_log in fp32
|
| 110 |
+
self.A_log = nn.Parameter(A_log)
|
| 111 |
+
self.A_log._no_weight_decay = True
|
| 112 |
+
|
| 113 |
+
# D "skip" parameter
|
| 114 |
+
self.D = nn.Parameter(torch.ones(self.d_inner, device=device)) # Keep in fp32
|
| 115 |
+
self.D._no_weight_decay = True
|
| 116 |
+
|
| 117 |
+
self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
|
| 118 |
+
|
| 119 |
+
def forward(self, hidden_states, inference_params=None):
|
| 120 |
+
"""
|
| 121 |
+
hidden_states: (B, L, D)
|
| 122 |
+
Returns: same shape as hidden_states
|
| 123 |
+
"""
|
| 124 |
+
batch, seqlen, dim = hidden_states.shape
|
| 125 |
+
|
| 126 |
+
conv_state, ssm_state = None, None
|
| 127 |
+
if inference_params is not None:
|
| 128 |
+
conv_state, ssm_state = self._get_states_from_cache(inference_params, batch)
|
| 129 |
+
if inference_params.seqlen_offset > 0:
|
| 130 |
+
# The states are updated inplace
|
| 131 |
+
out, _, _ = self.step(hidden_states, conv_state, ssm_state)
|
| 132 |
+
return out
|
| 133 |
+
|
| 134 |
+
# We do matmul and transpose BLH -> HBL at the same time
|
| 135 |
+
xz = rearrange(
|
| 136 |
+
self.in_proj.weight @ rearrange(hidden_states, "b l d -> d (b l)"),
|
| 137 |
+
"d (b l) -> b d l",
|
| 138 |
+
l=seqlen,
|
| 139 |
+
)
|
| 140 |
+
if self.in_proj.bias is not None:
|
| 141 |
+
xz = xz + rearrange(self.in_proj.bias.to(dtype=xz.dtype), "d -> d 1")
|
| 142 |
+
|
| 143 |
+
A = -torch.exp(self.A_log.float()) # (d_inner, d_state)
|
| 144 |
+
# In the backward pass we write dx and dz next to each other to avoid torch.cat
|
| 145 |
+
if self.use_fast_path and causal_conv1d_fn is not None and inference_params is None: # Doesn't support outputting the states
|
| 146 |
+
out = mamba_inner_fn(
|
| 147 |
+
xz,
|
| 148 |
+
self.conv1d.weight,
|
| 149 |
+
self.conv1d.bias,
|
| 150 |
+
self.x_proj.weight,
|
| 151 |
+
self.dt_proj.weight,
|
| 152 |
+
self.out_proj.weight,
|
| 153 |
+
self.out_proj.bias,
|
| 154 |
+
A,
|
| 155 |
+
None, # input-dependent B
|
| 156 |
+
None, # input-dependent C
|
| 157 |
+
self.D.float(),
|
| 158 |
+
delta_bias=self.dt_proj.bias.float(),
|
| 159 |
+
delta_softplus=True,
|
| 160 |
+
)
|
| 161 |
+
else:
|
| 162 |
+
x, z = xz.chunk(2, dim=1)
|
| 163 |
+
# Compute short convolution
|
| 164 |
+
if conv_state is not None:
|
| 165 |
+
# If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
|
| 166 |
+
# Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
|
| 167 |
+
conv_state.copy_(F.pad(x, (self.d_conv - x.shape[-1], 0))) # Update state (B D W)
|
| 168 |
+
if causal_conv1d_fn is None:
|
| 169 |
+
x = self.act(self.conv1d(x)[..., :seqlen])
|
| 170 |
+
else:
|
| 171 |
+
assert self.activation in ["silu", "swish"]
|
| 172 |
+
x = causal_conv1d_fn(
|
| 173 |
+
x=x,
|
| 174 |
+
weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
|
| 175 |
+
bias=self.conv1d.bias,
|
| 176 |
+
activation=self.activation,
|
| 177 |
+
)
|
| 178 |
+
|
| 179 |
+
# We're careful here about the layout, to avoid extra transposes.
|
| 180 |
+
# We want dt to have d as the slowest moving dimension
|
| 181 |
+
# and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
|
| 182 |
+
x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d")) # (bl d)
|
| 183 |
+
dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
|
| 184 |
+
dt = self.dt_proj.weight @ dt.t()
|
| 185 |
+
dt = rearrange(dt, "d (b l) -> b d l", l=seqlen)
|
| 186 |
+
B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
|
| 187 |
+
C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
|
| 188 |
+
assert self.activation in ["silu", "swish"]
|
| 189 |
+
y = selective_scan_fn(
|
| 190 |
+
x,
|
| 191 |
+
dt,
|
| 192 |
+
A,
|
| 193 |
+
B,
|
| 194 |
+
C,
|
| 195 |
+
self.D.float(),
|
| 196 |
+
z=z,
|
| 197 |
+
delta_bias=self.dt_proj.bias.float(),
|
| 198 |
+
delta_softplus=True,
|
| 199 |
+
return_last_state=ssm_state is not None,
|
| 200 |
+
)
|
| 201 |
+
if ssm_state is not None:
|
| 202 |
+
y, last_state = y
|
| 203 |
+
ssm_state.copy_(last_state)
|
| 204 |
+
y = rearrange(y, "b d l -> b l d")
|
| 205 |
+
out = self.out_proj(y)
|
| 206 |
+
return out
|
| 207 |
+
|
| 208 |
+
def step(self, hidden_states, conv_state, ssm_state):
|
| 209 |
+
dtype = hidden_states.dtype
|
| 210 |
+
assert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now"
|
| 211 |
+
xz = self.in_proj(hidden_states.squeeze(1)) # (B 2D)
|
| 212 |
+
x, z = xz.chunk(2, dim=-1) # (B D)
|
| 213 |
+
|
| 214 |
+
# Conv step
|
| 215 |
+
if causal_conv1d_update is None:
|
| 216 |
+
conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W)
|
| 217 |
+
conv_state[:, :, -1] = x
|
| 218 |
+
x = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D)
|
| 219 |
+
if self.conv1d.bias is not None:
|
| 220 |
+
x = x + self.conv1d.bias
|
| 221 |
+
x = self.act(x).to(dtype=dtype)
|
| 222 |
+
else:
|
| 223 |
+
x = causal_conv1d_update(
|
| 224 |
+
x,
|
| 225 |
+
conv_state,
|
| 226 |
+
rearrange(self.conv1d.weight, "d 1 w -> d w"),
|
| 227 |
+
self.conv1d.bias,
|
| 228 |
+
self.activation,
|
| 229 |
+
)
|
| 230 |
+
|
| 231 |
+
x_db = self.x_proj(x) # (B dt_rank+2*d_state)
|
| 232 |
+
dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
|
| 233 |
+
# Don't add dt_bias here
|
| 234 |
+
dt = F.linear(dt, self.dt_proj.weight) # (B d_inner)
|
| 235 |
+
A = -torch.exp(self.A_log.float()) # (d_inner, d_state)
|
| 236 |
+
|
| 237 |
+
# SSM step
|
| 238 |
+
if selective_state_update is None:
|
| 239 |
+
# Discretize A and B
|
| 240 |
+
dt = F.softplus(dt + self.dt_proj.bias.to(dtype=dt.dtype))
|
| 241 |
+
dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A))
|
| 242 |
+
dB = torch.einsum("bd,bn->bdn", dt, B)
|
| 243 |
+
ssm_state.copy_(ssm_state * dA + rearrange(x, "b d -> b d 1") * dB)
|
| 244 |
+
y = torch.einsum("bdn,bn->bd", ssm_state.to(dtype), C)
|
| 245 |
+
y = y + self.D.to(dtype) * x
|
| 246 |
+
y = y * self.act(z) # (B D)
|
| 247 |
+
else:
|
| 248 |
+
y = selective_state_update(
|
| 249 |
+
ssm_state, x, dt, A, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True
|
| 250 |
+
)
|
| 251 |
+
|
| 252 |
+
out = self.out_proj(y)
|
| 253 |
+
return out.unsqueeze(1), conv_state, ssm_state
|
| 254 |
+
|
| 255 |
+
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
|
| 256 |
+
device = self.out_proj.weight.device
|
| 257 |
+
conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype
|
| 258 |
+
conv_state = torch.zeros(
|
| 259 |
+
batch_size, self.d_model * self.expand, self.d_conv, device=device, dtype=conv_dtype
|
| 260 |
+
)
|
| 261 |
+
ssm_dtype = self.dt_proj.weight.dtype if dtype is None else dtype
|
| 262 |
+
# ssm_dtype = torch.float32
|
| 263 |
+
ssm_state = torch.zeros(
|
| 264 |
+
batch_size, self.d_model * self.expand, self.d_state, device=device, dtype=ssm_dtype
|
| 265 |
+
)
|
| 266 |
+
return conv_state, ssm_state
|
| 267 |
+
|
| 268 |
+
def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False):
|
| 269 |
+
assert self.layer_idx is not None
|
| 270 |
+
if self.layer_idx not in inference_params.key_value_memory_dict:
|
| 271 |
+
batch_shape = (batch_size,)
|
| 272 |
+
conv_state = torch.zeros(
|
| 273 |
+
batch_size,
|
| 274 |
+
self.d_model * self.expand,
|
| 275 |
+
self.d_conv,
|
| 276 |
+
device=self.conv1d.weight.device,
|
| 277 |
+
dtype=self.conv1d.weight.dtype,
|
| 278 |
+
)
|
| 279 |
+
ssm_state = torch.zeros(
|
| 280 |
+
batch_size,
|
| 281 |
+
self.d_model * self.expand,
|
| 282 |
+
self.d_state,
|
| 283 |
+
device=self.dt_proj.weight.device,
|
| 284 |
+
dtype=self.dt_proj.weight.dtype,
|
| 285 |
+
# dtype=torch.float32,
|
| 286 |
+
)
|
| 287 |
+
inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state)
|
| 288 |
+
else:
|
| 289 |
+
conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx]
|
| 290 |
+
# TODO: What if batch size changes between generation, and we reuse the same states?
|
| 291 |
+
if initialize_states:
|
| 292 |
+
conv_state.zero_()
|
| 293 |
+
ssm_state.zero_()
|
| 294 |
+
return conv_state, ssm_state
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
class Block(nn.Module):
|
| 298 |
+
def __init__(
|
| 299 |
+
self, dim, mixer_cls, norm_cls=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False
|
| 300 |
+
):
|
| 301 |
+
"""
|
| 302 |
+
Simple block wrapping a mixer class with LayerNorm/RMSNorm and residual connection"
|
| 303 |
+
|
| 304 |
+
This Block has a slightly different structure compared to a regular
|
| 305 |
+
prenorm Transformer block.
|
| 306 |
+
The standard block is: LN -> MHA/MLP -> Add.
|
| 307 |
+
[Ref: https://arxiv.org/abs/2002.04745]
|
| 308 |
+
Here we have: Add -> LN -> Mixer, returning both
|
| 309 |
+
the hidden_states (output of the mixer) and the residual.
|
| 310 |
+
This is purely for performance reasons, as we can fuse add and LayerNorm.
|
| 311 |
+
The residual needs to be provided (except for the very first block).
|
| 312 |
+
"""
|
| 313 |
+
super().__init__()
|
| 314 |
+
self.residual_in_fp32 = residual_in_fp32
|
| 315 |
+
self.fused_add_norm = fused_add_norm
|
| 316 |
+
self.mixer = mixer_cls(dim)
|
| 317 |
+
self.norm = norm_cls(dim)
|
| 318 |
+
if self.fused_add_norm:
|
| 319 |
+
assert RMSNorm is not None, "RMSNorm import fails"
|
| 320 |
+
assert isinstance(
|
| 321 |
+
self.norm, (nn.LayerNorm, RMSNorm)
|
| 322 |
+
), "Only LayerNorm and RMSNorm are supported for fused_add_norm"
|
| 323 |
+
|
| 324 |
+
def forward(
|
| 325 |
+
self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None
|
| 326 |
+
):
|
| 327 |
+
r"""Pass the input through the encoder layer.
|
| 328 |
+
|
| 329 |
+
Args:
|
| 330 |
+
hidden_states: the sequence to the encoder layer (required).
|
| 331 |
+
residual: hidden_states = Mixer(LN(residual))
|
| 332 |
+
"""
|
| 333 |
+
if not self.fused_add_norm:
|
| 334 |
+
residual = (hidden_states + residual) if residual is not None else hidden_states
|
| 335 |
+
hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
|
| 336 |
+
if self.residual_in_fp32:
|
| 337 |
+
residual = residual.to(torch.float32)
|
| 338 |
+
else:
|
| 339 |
+
fused_add_norm_fn = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn
|
| 340 |
+
hidden_states, residual = fused_add_norm_fn(
|
| 341 |
+
hidden_states,
|
| 342 |
+
self.norm.weight,
|
| 343 |
+
self.norm.bias,
|
| 344 |
+
residual=residual,
|
| 345 |
+
prenorm=True,
|
| 346 |
+
residual_in_fp32=self.residual_in_fp32,
|
| 347 |
+
eps=self.norm.eps,
|
| 348 |
+
)
|
| 349 |
+
hidden_states = self.mixer(hidden_states, inference_params=inference_params)
|
| 350 |
+
return hidden_states, residual
|
| 351 |
+
|
| 352 |
+
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
|
| 353 |
+
return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
|
mamba_ssm/ops/__init__.py
ADDED
|
File without changes
|
mamba_ssm/ops/selective_scan_interface.py
ADDED
|
@@ -0,0 +1,357 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2023, Tri Dao, Albert Gu.
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
from torch.cuda.amp import custom_bwd, custom_fwd
|
| 6 |
+
|
| 7 |
+
from einops import rearrange, repeat
|
| 8 |
+
|
| 9 |
+
try:
|
| 10 |
+
from causal_conv1d import causal_conv1d_fn
|
| 11 |
+
import causal_conv1d_cuda
|
| 12 |
+
except ImportError:
|
| 13 |
+
causal_conv1d_fn = None
|
| 14 |
+
causal_conv1d_cuda = None
|
| 15 |
+
|
| 16 |
+
import selective_scan_cuda
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class SelectiveScanFn(torch.autograd.Function):
|
| 20 |
+
|
| 21 |
+
@staticmethod
|
| 22 |
+
def forward(ctx, u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
|
| 23 |
+
return_last_state=False):
|
| 24 |
+
if u.stride(-1) != 1:
|
| 25 |
+
u = u.contiguous()
|
| 26 |
+
if delta.stride(-1) != 1:
|
| 27 |
+
delta = delta.contiguous()
|
| 28 |
+
if D is not None:
|
| 29 |
+
D = D.contiguous()
|
| 30 |
+
if B.stride(-1) != 1:
|
| 31 |
+
B = B.contiguous()
|
| 32 |
+
if C.stride(-1) != 1:
|
| 33 |
+
C = C.contiguous()
|
| 34 |
+
if z is not None and z.stride(-1) != 1:
|
| 35 |
+
z = z.contiguous()
|
| 36 |
+
if B.dim() == 3:
|
| 37 |
+
B = rearrange(B, "b dstate l -> b 1 dstate l")
|
| 38 |
+
ctx.squeeze_B = True
|
| 39 |
+
if C.dim() == 3:
|
| 40 |
+
C = rearrange(C, "b dstate l -> b 1 dstate l")
|
| 41 |
+
ctx.squeeze_C = True
|
| 42 |
+
out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus)
|
| 43 |
+
ctx.delta_softplus = delta_softplus
|
| 44 |
+
ctx.has_z = z is not None
|
| 45 |
+
last_state = x[:, :, -1, 1::2] # (batch, dim, dstate)
|
| 46 |
+
if not ctx.has_z:
|
| 47 |
+
ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x)
|
| 48 |
+
return out if not return_last_state else (out, last_state)
|
| 49 |
+
else:
|
| 50 |
+
ctx.save_for_backward(u, delta, A, B, C, D, z, delta_bias, x, out)
|
| 51 |
+
out_z = rest[0]
|
| 52 |
+
return out_z if not return_last_state else (out_z, last_state)
|
| 53 |
+
|
| 54 |
+
@staticmethod
|
| 55 |
+
def backward(ctx, dout, *args):
|
| 56 |
+
if not ctx.has_z:
|
| 57 |
+
u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors
|
| 58 |
+
z = None
|
| 59 |
+
out = None
|
| 60 |
+
else:
|
| 61 |
+
u, delta, A, B, C, D, z, delta_bias, x, out = ctx.saved_tensors
|
| 62 |
+
if dout.stride(-1) != 1:
|
| 63 |
+
dout = dout.contiguous()
|
| 64 |
+
# The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
|
| 65 |
+
# backward of selective_scan_cuda with the backward of chunk).
|
| 66 |
+
# Here we just pass in None and dz will be allocated in the C++ code.
|
| 67 |
+
du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd(
|
| 68 |
+
u, delta, A, B, C, D, z, delta_bias, dout, x, out, None, ctx.delta_softplus,
|
| 69 |
+
False # option to recompute out_z, not used here
|
| 70 |
+
)
|
| 71 |
+
dz = rest[0] if ctx.has_z else None
|
| 72 |
+
dB = dB.squeeze(1) if getattr(ctx, "squeeze_B", False) else dB
|
| 73 |
+
dC = dC.squeeze(1) if getattr(ctx, "squeeze_C", False) else dC
|
| 74 |
+
return (du, ddelta, dA, dB, dC,
|
| 75 |
+
dD if D is not None else None,
|
| 76 |
+
dz,
|
| 77 |
+
ddelta_bias if delta_bias is not None else None,
|
| 78 |
+
None,
|
| 79 |
+
None)
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
|
| 83 |
+
return_last_state=False):
|
| 84 |
+
"""if return_last_state is True, returns (out, last_state)
|
| 85 |
+
last_state has shape (batch, dim, dstate). Note that the gradient of the last state is
|
| 86 |
+
not considered in the backward pass.
|
| 87 |
+
"""
|
| 88 |
+
return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state)
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
|
| 92 |
+
return_last_state=False):
|
| 93 |
+
"""
|
| 94 |
+
u: r(B D L)
|
| 95 |
+
delta: r(B D L)
|
| 96 |
+
A: c(D N) or r(D N)
|
| 97 |
+
B: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
|
| 98 |
+
C: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
|
| 99 |
+
D: r(D)
|
| 100 |
+
z: r(B D L)
|
| 101 |
+
delta_bias: r(D), fp32
|
| 102 |
+
|
| 103 |
+
out: r(B D L)
|
| 104 |
+
last_state (optional): r(B D dstate) or c(B D dstate)
|
| 105 |
+
"""
|
| 106 |
+
dtype_in = u.dtype
|
| 107 |
+
u = u.float()
|
| 108 |
+
delta = delta.float()
|
| 109 |
+
if delta_bias is not None:
|
| 110 |
+
delta = delta + delta_bias[..., None].float()
|
| 111 |
+
if delta_softplus:
|
| 112 |
+
delta = F.softplus(delta)
|
| 113 |
+
batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1]
|
| 114 |
+
is_variable_B = B.dim() >= 3
|
| 115 |
+
is_variable_C = C.dim() >= 3
|
| 116 |
+
if A.is_complex():
|
| 117 |
+
if is_variable_B:
|
| 118 |
+
B = torch.view_as_complex(rearrange(B.float(), "... (L two) -> ... L two", two=2))
|
| 119 |
+
if is_variable_C:
|
| 120 |
+
C = torch.view_as_complex(rearrange(C.float(), "... (L two) -> ... L two", two=2))
|
| 121 |
+
else:
|
| 122 |
+
B = B.float()
|
| 123 |
+
C = C.float()
|
| 124 |
+
x = A.new_zeros((batch, dim, dstate))
|
| 125 |
+
ys = []
|
| 126 |
+
deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A))
|
| 127 |
+
if not is_variable_B:
|
| 128 |
+
deltaB_u = torch.einsum('bdl,dn,bdl->bdln', delta, B, u)
|
| 129 |
+
else:
|
| 130 |
+
if B.dim() == 3:
|
| 131 |
+
deltaB_u = torch.einsum('bdl,bnl,bdl->bdln', delta, B, u)
|
| 132 |
+
else:
|
| 133 |
+
B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1])
|
| 134 |
+
deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u)
|
| 135 |
+
if is_variable_C and C.dim() == 4:
|
| 136 |
+
C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1])
|
| 137 |
+
last_state = None
|
| 138 |
+
for i in range(u.shape[2]):
|
| 139 |
+
x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
|
| 140 |
+
if not is_variable_C:
|
| 141 |
+
y = torch.einsum('bdn,dn->bd', x, C)
|
| 142 |
+
else:
|
| 143 |
+
if C.dim() == 3:
|
| 144 |
+
y = torch.einsum('bdn,bn->bd', x, C[:, :, i])
|
| 145 |
+
else:
|
| 146 |
+
y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i])
|
| 147 |
+
if i == u.shape[2] - 1:
|
| 148 |
+
last_state = x
|
| 149 |
+
if y.is_complex():
|
| 150 |
+
y = y.real * 2
|
| 151 |
+
ys.append(y)
|
| 152 |
+
y = torch.stack(ys, dim=2) # (batch dim L)
|
| 153 |
+
out = y if D is None else y + u * rearrange(D, "d -> d 1")
|
| 154 |
+
if z is not None:
|
| 155 |
+
out = out * F.silu(z)
|
| 156 |
+
out = out.to(dtype=dtype_in)
|
| 157 |
+
return out if not return_last_state else (out, last_state)
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
class MambaInnerFn(torch.autograd.Function):
|
| 161 |
+
|
| 162 |
+
@staticmethod
|
| 163 |
+
@custom_fwd
|
| 164 |
+
def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
|
| 165 |
+
out_proj_weight, out_proj_bias,
|
| 166 |
+
A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
|
| 167 |
+
C_proj_bias=None, delta_softplus=True, checkpoint_lvl=1):
|
| 168 |
+
"""
|
| 169 |
+
xz: (batch, dim, seqlen)
|
| 170 |
+
"""
|
| 171 |
+
assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d."
|
| 172 |
+
assert checkpoint_lvl in [0, 1]
|
| 173 |
+
L = xz.shape[-1]
|
| 174 |
+
delta_rank = delta_proj_weight.shape[1]
|
| 175 |
+
d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
|
| 176 |
+
if torch.is_autocast_enabled():
|
| 177 |
+
x_proj_weight = x_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
|
| 178 |
+
delta_proj_weight = delta_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
|
| 179 |
+
out_proj_weight = out_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
|
| 180 |
+
out_proj_bias = (out_proj_bias.to(dtype=torch.get_autocast_gpu_dtype())
|
| 181 |
+
if out_proj_bias is not None else None)
|
| 182 |
+
if xz.stride(-1) != 1:
|
| 183 |
+
xz = xz.contiguous()
|
| 184 |
+
conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w")
|
| 185 |
+
x, z = xz.chunk(2, dim=1)
|
| 186 |
+
conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None
|
| 187 |
+
conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(
|
| 188 |
+
x, conv1d_weight, conv1d_bias, None, None, None, True
|
| 189 |
+
)
|
| 190 |
+
# We're being very careful here about the layout, to avoid extra transposes.
|
| 191 |
+
# We want delta to have d as the slowest moving dimension
|
| 192 |
+
# and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
|
| 193 |
+
x_dbl = F.linear(rearrange(conv1d_out, 'b d l -> (b l) d'), x_proj_weight) # (bl d)
|
| 194 |
+
delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l = L)
|
| 195 |
+
ctx.is_variable_B = B is None
|
| 196 |
+
ctx.is_variable_C = C is None
|
| 197 |
+
ctx.B_proj_bias_is_None = B_proj_bias is None
|
| 198 |
+
ctx.C_proj_bias_is_None = C_proj_bias is None
|
| 199 |
+
if B is None: # variable B
|
| 200 |
+
B = x_dbl[:, delta_rank:delta_rank + d_state] # (bl dstate)
|
| 201 |
+
if B_proj_bias is not None:
|
| 202 |
+
B = B + B_proj_bias.to(dtype=B.dtype)
|
| 203 |
+
if not A.is_complex():
|
| 204 |
+
# B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
|
| 205 |
+
B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
|
| 206 |
+
else:
|
| 207 |
+
B = rearrange(B, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
|
| 208 |
+
else:
|
| 209 |
+
if B.stride(-1) != 1:
|
| 210 |
+
B = B.contiguous()
|
| 211 |
+
if C is None: # variable C
|
| 212 |
+
C = x_dbl[:, -d_state:] # (bl dstate)
|
| 213 |
+
if C_proj_bias is not None:
|
| 214 |
+
C = C + C_proj_bias.to(dtype=C.dtype)
|
| 215 |
+
if not A.is_complex():
|
| 216 |
+
# C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
|
| 217 |
+
C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
|
| 218 |
+
else:
|
| 219 |
+
C = rearrange(C, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
|
| 220 |
+
else:
|
| 221 |
+
if C.stride(-1) != 1:
|
| 222 |
+
C = C.contiguous()
|
| 223 |
+
if D is not None:
|
| 224 |
+
D = D.contiguous()
|
| 225 |
+
out, scan_intermediates, out_z = selective_scan_cuda.fwd(
|
| 226 |
+
conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus
|
| 227 |
+
)
|
| 228 |
+
ctx.delta_softplus = delta_softplus
|
| 229 |
+
ctx.out_proj_bias_is_None = out_proj_bias is None
|
| 230 |
+
ctx.checkpoint_lvl = checkpoint_lvl
|
| 231 |
+
if checkpoint_lvl >= 1: # Will recompute conv1d_out and delta in the backward pass
|
| 232 |
+
conv1d_out, delta = None, None
|
| 233 |
+
ctx.save_for_backward(xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight,
|
| 234 |
+
delta_proj_weight, out_proj_weight, conv1d_out, delta,
|
| 235 |
+
A, B, C, D, delta_bias, scan_intermediates, out)
|
| 236 |
+
return F.linear(rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias)
|
| 237 |
+
|
| 238 |
+
@staticmethod
|
| 239 |
+
@custom_bwd
|
| 240 |
+
def backward(ctx, dout):
|
| 241 |
+
# dout: (batch, seqlen, dim)
|
| 242 |
+
assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d."
|
| 243 |
+
(xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, delta_proj_weight, out_proj_weight,
|
| 244 |
+
conv1d_out, delta, A, B, C, D, delta_bias, scan_intermediates, out) = ctx.saved_tensors
|
| 245 |
+
L = xz.shape[-1]
|
| 246 |
+
delta_rank = delta_proj_weight.shape[1]
|
| 247 |
+
d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
|
| 248 |
+
x, z = xz.chunk(2, dim=1)
|
| 249 |
+
if dout.stride(-1) != 1:
|
| 250 |
+
dout = dout.contiguous()
|
| 251 |
+
if ctx.checkpoint_lvl == 1:
|
| 252 |
+
conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(
|
| 253 |
+
x, conv1d_weight, conv1d_bias, None, None, None, True
|
| 254 |
+
)
|
| 255 |
+
delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(),
|
| 256 |
+
"d (b l) -> b d l", l = L)
|
| 257 |
+
# The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
|
| 258 |
+
# backward of selective_scan_cuda with the backward of chunk).
|
| 259 |
+
dxz = torch.empty_like(xz) # (batch, dim, seqlen)
|
| 260 |
+
dx, dz = dxz.chunk(2, dim=1)
|
| 261 |
+
dout = rearrange(dout, "b l e -> e (b l)")
|
| 262 |
+
dout_y = rearrange(out_proj_weight.t() @ dout, "d (b l) -> b d l", l=L)
|
| 263 |
+
dconv1d_out, ddelta, dA, dB, dC, dD, ddelta_bias, dz, out_z = selective_scan_cuda.bwd(
|
| 264 |
+
conv1d_out, delta, A, B, C, D, z, delta_bias, dout_y, scan_intermediates, out, dz,
|
| 265 |
+
ctx.delta_softplus,
|
| 266 |
+
True # option to recompute out_z
|
| 267 |
+
)
|
| 268 |
+
dout_proj_weight = torch.einsum("eB,dB->ed", dout, rearrange(out_z, "b d l -> d (b l)"))
|
| 269 |
+
dout_proj_bias = dout.sum(dim=(0, 1)) if not ctx.out_proj_bias_is_None else None
|
| 270 |
+
dD = dD if D is not None else None
|
| 271 |
+
dx_dbl = torch.empty_like(x_dbl)
|
| 272 |
+
dB_proj_bias = None
|
| 273 |
+
if ctx.is_variable_B:
|
| 274 |
+
if not A.is_complex():
|
| 275 |
+
dB = rearrange(dB, "b 1 dstate l -> (b l) dstate").contiguous()
|
| 276 |
+
else:
|
| 277 |
+
dB = rearrange(dB, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
|
| 278 |
+
dB_proj_bias = dB.sum(0) if not ctx.B_proj_bias_is_None else None
|
| 279 |
+
dx_dbl[:, delta_rank:delta_rank + d_state] = dB # (bl d)
|
| 280 |
+
dB = None
|
| 281 |
+
dC_proj_bias = None
|
| 282 |
+
if ctx.is_variable_C:
|
| 283 |
+
if not A.is_complex():
|
| 284 |
+
dC = rearrange(dC, "b 1 dstate l -> (b l) dstate").contiguous()
|
| 285 |
+
else:
|
| 286 |
+
dC = rearrange(dC, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
|
| 287 |
+
dC_proj_bias = dC.sum(0) if not ctx.C_proj_bias_is_None else None
|
| 288 |
+
dx_dbl[:, -d_state:] = dC # (bl d)
|
| 289 |
+
dC = None
|
| 290 |
+
ddelta = rearrange(ddelta, "b d l -> d (b l)")
|
| 291 |
+
ddelta_proj_weight = torch.einsum("dB,Br->dr", ddelta, x_dbl[:, :delta_rank])
|
| 292 |
+
dx_dbl[:, :delta_rank] = torch.einsum("dB,dr->Br", ddelta, delta_proj_weight)
|
| 293 |
+
dconv1d_out = rearrange(dconv1d_out, "b d l -> d (b l)")
|
| 294 |
+
dx_proj_weight = torch.einsum("Br,Bd->rd", dx_dbl, rearrange(conv1d_out, "b d l -> (b l) d"))
|
| 295 |
+
dconv1d_out = torch.addmm(dconv1d_out, x_proj_weight.t(), dx_dbl.t(), out=dconv1d_out)
|
| 296 |
+
dconv1d_out = rearrange(dconv1d_out, "d (b l) -> b d l", b=x.shape[0], l=x.shape[-1])
|
| 297 |
+
# The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
|
| 298 |
+
# backward of conv1d with the backward of chunk).
|
| 299 |
+
dx, dconv1d_weight, dconv1d_bias, *_ = causal_conv1d_cuda.causal_conv1d_bwd(
|
| 300 |
+
x, conv1d_weight, conv1d_bias, dconv1d_out, None, None, None, dx, False, True
|
| 301 |
+
)
|
| 302 |
+
dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None
|
| 303 |
+
dconv1d_weight = rearrange(dconv1d_weight, "d w -> d 1 w")
|
| 304 |
+
return (dxz, dconv1d_weight, dconv1d_bias, dx_proj_weight, ddelta_proj_weight,
|
| 305 |
+
dout_proj_weight, dout_proj_bias,
|
| 306 |
+
dA, dB, dC, dD,
|
| 307 |
+
ddelta_bias if delta_bias is not None else None,
|
| 308 |
+
dB_proj_bias, dC_proj_bias, None)
|
| 309 |
+
|
| 310 |
+
|
| 311 |
+
def mamba_inner_fn(
|
| 312 |
+
xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
|
| 313 |
+
out_proj_weight, out_proj_bias,
|
| 314 |
+
A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
|
| 315 |
+
C_proj_bias=None, delta_softplus=True
|
| 316 |
+
):
|
| 317 |
+
return MambaInnerFn.apply(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
|
| 318 |
+
out_proj_weight, out_proj_bias,
|
| 319 |
+
A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus)
|
| 320 |
+
|
| 321 |
+
|
| 322 |
+
def mamba_inner_ref(
|
| 323 |
+
xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
|
| 324 |
+
out_proj_weight, out_proj_bias,
|
| 325 |
+
A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
|
| 326 |
+
C_proj_bias=None, delta_softplus=True
|
| 327 |
+
):
|
| 328 |
+
assert causal_conv1d_fn is not None, "causal_conv1d_fn is not available. Please install causal-conv1d."
|
| 329 |
+
L = xz.shape[-1]
|
| 330 |
+
delta_rank = delta_proj_weight.shape[1]
|
| 331 |
+
d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
|
| 332 |
+
x, z = xz.chunk(2, dim=1)
|
| 333 |
+
x = causal_conv1d_fn(x, rearrange(conv1d_weight, "d 1 w -> d w"), conv1d_bias, activation="silu")
|
| 334 |
+
# We're being very careful here about the layout, to avoid extra transposes.
|
| 335 |
+
# We want delta to have d as the slowest moving dimension
|
| 336 |
+
# and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
|
| 337 |
+
x_dbl = F.linear(rearrange(x, 'b d l -> (b l) d'), x_proj_weight) # (bl d)
|
| 338 |
+
delta = delta_proj_weight @ x_dbl[:, :delta_rank].t()
|
| 339 |
+
delta = rearrange(delta, "d (b l) -> b d l", l=L)
|
| 340 |
+
if B is None: # variable B
|
| 341 |
+
B = x_dbl[:, delta_rank:delta_rank + d_state] # (bl d)
|
| 342 |
+
if B_proj_bias is not None:
|
| 343 |
+
B = B + B_proj_bias.to(dtype=B.dtype)
|
| 344 |
+
if not A.is_complex():
|
| 345 |
+
B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
|
| 346 |
+
else:
|
| 347 |
+
B = rearrange(B, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous()
|
| 348 |
+
if C is None: # variable B
|
| 349 |
+
C = x_dbl[:, -d_state:] # (bl d)
|
| 350 |
+
if C_proj_bias is not None:
|
| 351 |
+
C = C + C_proj_bias.to(dtype=C.dtype)
|
| 352 |
+
if not A.is_complex():
|
| 353 |
+
C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
|
| 354 |
+
else:
|
| 355 |
+
C = rearrange(C, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous()
|
| 356 |
+
y = selective_scan_fn(x, delta, A, B, C, D, z=z, delta_bias=delta_bias, delta_softplus=True)
|
| 357 |
+
return F.linear(rearrange(y, "b d l -> b l d"), out_proj_weight, out_proj_bias)
|
mamba_ssm/ops/triton/__init__.py
ADDED
|
File without changes
|
mamba_ssm/ops/triton/layernorm.py
ADDED
|
@@ -0,0 +1,635 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2023, Tri Dao.
|
| 2 |
+
# Implement residual + layer_norm / rms_norm.
|
| 3 |
+
|
| 4 |
+
# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
|
| 5 |
+
# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate.
|
| 6 |
+
# This is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
|
| 7 |
+
# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.
|
| 8 |
+
|
| 9 |
+
import math
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
from torch.cuda.amp import custom_fwd, custom_bwd
|
| 14 |
+
|
| 15 |
+
import triton
|
| 16 |
+
import triton.language as tl
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def layer_norm_ref(x, weight, bias, residual=None, eps=1e-6, prenorm=False, upcast=False):
|
| 20 |
+
dtype = x.dtype
|
| 21 |
+
if upcast:
|
| 22 |
+
weight = weight.float()
|
| 23 |
+
bias = bias.float() if bias is not None else None
|
| 24 |
+
if upcast:
|
| 25 |
+
x = x.float()
|
| 26 |
+
residual = residual.float() if residual is not None else residual
|
| 27 |
+
if residual is not None:
|
| 28 |
+
x = (x + residual).to(x.dtype)
|
| 29 |
+
out = F.layer_norm(x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps).to(
|
| 30 |
+
dtype
|
| 31 |
+
)
|
| 32 |
+
return out if not prenorm else (out, x)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def rms_norm_ref(x, weight, bias, residual=None, eps=1e-6, prenorm=False, upcast=False):
|
| 36 |
+
dtype = x.dtype
|
| 37 |
+
if upcast:
|
| 38 |
+
weight = weight.float()
|
| 39 |
+
bias = bias.float() if bias is not None else None
|
| 40 |
+
if upcast:
|
| 41 |
+
x = x.float()
|
| 42 |
+
residual = residual.float() if residual is not None else residual
|
| 43 |
+
if residual is not None:
|
| 44 |
+
x = (x + residual).to(x.dtype)
|
| 45 |
+
rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
|
| 46 |
+
out = (x * rstd * weight) + bias if bias is not None else (x * rstd * weight)
|
| 47 |
+
out = out.to(dtype)
|
| 48 |
+
return out if not prenorm else (out, x)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
@triton.autotune(
|
| 52 |
+
configs=[
|
| 53 |
+
triton.Config({}, num_warps=1),
|
| 54 |
+
triton.Config({}, num_warps=2),
|
| 55 |
+
triton.Config({}, num_warps=4),
|
| 56 |
+
triton.Config({}, num_warps=8),
|
| 57 |
+
triton.Config({}, num_warps=16),
|
| 58 |
+
triton.Config({}, num_warps=32),
|
| 59 |
+
],
|
| 60 |
+
key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"],
|
| 61 |
+
)
|
| 62 |
+
# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
|
| 63 |
+
# @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None})
|
| 64 |
+
@triton.jit
|
| 65 |
+
def _layer_norm_fwd_1pass_kernel(
|
| 66 |
+
X, # pointer to the input
|
| 67 |
+
Y, # pointer to the output
|
| 68 |
+
W, # pointer to the weights
|
| 69 |
+
B, # pointer to the biases
|
| 70 |
+
RESIDUAL, # pointer to the residual
|
| 71 |
+
RESIDUAL_OUT, # pointer to the residual
|
| 72 |
+
Mean, # pointer to the mean
|
| 73 |
+
Rstd, # pointer to the 1/std
|
| 74 |
+
stride_x_row, # how much to increase the pointer when moving by 1 row
|
| 75 |
+
stride_y_row,
|
| 76 |
+
stride_res_row,
|
| 77 |
+
stride_res_out_row,
|
| 78 |
+
N, # number of columns in X
|
| 79 |
+
eps, # epsilon to avoid division by zero
|
| 80 |
+
IS_RMS_NORM: tl.constexpr,
|
| 81 |
+
BLOCK_N: tl.constexpr,
|
| 82 |
+
HAS_RESIDUAL: tl.constexpr,
|
| 83 |
+
STORE_RESIDUAL_OUT: tl.constexpr,
|
| 84 |
+
HAS_BIAS: tl.constexpr,
|
| 85 |
+
):
|
| 86 |
+
# Map the program id to the row of X and Y it should compute.
|
| 87 |
+
row = tl.program_id(0)
|
| 88 |
+
X += row * stride_x_row
|
| 89 |
+
Y += row * stride_y_row
|
| 90 |
+
if HAS_RESIDUAL:
|
| 91 |
+
RESIDUAL += row * stride_res_row
|
| 92 |
+
if STORE_RESIDUAL_OUT:
|
| 93 |
+
RESIDUAL_OUT += row * stride_res_out_row
|
| 94 |
+
# Compute mean and variance
|
| 95 |
+
cols = tl.arange(0, BLOCK_N)
|
| 96 |
+
x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
|
| 97 |
+
if HAS_RESIDUAL:
|
| 98 |
+
residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)
|
| 99 |
+
x += residual
|
| 100 |
+
if STORE_RESIDUAL_OUT:
|
| 101 |
+
tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)
|
| 102 |
+
if not IS_RMS_NORM:
|
| 103 |
+
mean = tl.sum(x, axis=0) / N
|
| 104 |
+
tl.store(Mean + row, mean)
|
| 105 |
+
xbar = tl.where(cols < N, x - mean, 0.0)
|
| 106 |
+
var = tl.sum(xbar * xbar, axis=0) / N
|
| 107 |
+
else:
|
| 108 |
+
xbar = tl.where(cols < N, x, 0.0)
|
| 109 |
+
var = tl.sum(xbar * xbar, axis=0) / N
|
| 110 |
+
rstd = 1 / tl.sqrt(var + eps)
|
| 111 |
+
tl.store(Rstd + row, rstd)
|
| 112 |
+
# Normalize and apply linear transformation
|
| 113 |
+
mask = cols < N
|
| 114 |
+
w = tl.load(W + cols, mask=mask).to(tl.float32)
|
| 115 |
+
if HAS_BIAS:
|
| 116 |
+
b = tl.load(B + cols, mask=mask).to(tl.float32)
|
| 117 |
+
x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
|
| 118 |
+
y = x_hat * w + b if HAS_BIAS else x_hat * w
|
| 119 |
+
# Write output
|
| 120 |
+
tl.store(Y + cols, y, mask=mask)
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def _layer_norm_fwd(
|
| 124 |
+
x, weight, bias, eps, residual=None, out_dtype=None, residual_dtype=None, is_rms_norm=False
|
| 125 |
+
):
|
| 126 |
+
if residual is not None:
|
| 127 |
+
residual_dtype = residual.dtype
|
| 128 |
+
M, N = x.shape
|
| 129 |
+
assert x.stride(-1) == 1
|
| 130 |
+
if residual is not None:
|
| 131 |
+
assert residual.stride(-1) == 1
|
| 132 |
+
assert residual.shape == (M, N)
|
| 133 |
+
assert weight.shape == (N,)
|
| 134 |
+
assert weight.stride(-1) == 1
|
| 135 |
+
if bias is not None:
|
| 136 |
+
assert bias.stride(-1) == 1
|
| 137 |
+
assert bias.shape == (N,)
|
| 138 |
+
# allocate output
|
| 139 |
+
y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
|
| 140 |
+
assert y.stride(-1) == 1
|
| 141 |
+
if residual is not None or (residual_dtype is not None and residual_dtype != x.dtype):
|
| 142 |
+
residual_out = torch.empty(M, N, device=x.device, dtype=residual_dtype)
|
| 143 |
+
assert residual_out.stride(-1) == 1
|
| 144 |
+
else:
|
| 145 |
+
residual_out = None
|
| 146 |
+
mean = torch.empty((M,), dtype=torch.float32, device=x.device) if not is_rms_norm else None
|
| 147 |
+
rstd = torch.empty((M,), dtype=torch.float32, device=x.device)
|
| 148 |
+
# Less than 64KB per feature: enqueue fused kernel
|
| 149 |
+
MAX_FUSED_SIZE = 65536 // x.element_size()
|
| 150 |
+
BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
|
| 151 |
+
if N > BLOCK_N:
|
| 152 |
+
raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
|
| 153 |
+
# heuristics for number of warps
|
| 154 |
+
with torch.cuda.device(x.device.index):
|
| 155 |
+
_layer_norm_fwd_1pass_kernel[(M,)](
|
| 156 |
+
x,
|
| 157 |
+
y,
|
| 158 |
+
weight,
|
| 159 |
+
bias,
|
| 160 |
+
residual,
|
| 161 |
+
residual_out,
|
| 162 |
+
mean,
|
| 163 |
+
rstd,
|
| 164 |
+
x.stride(0),
|
| 165 |
+
y.stride(0),
|
| 166 |
+
residual.stride(0) if residual is not None else 0,
|
| 167 |
+
residual_out.stride(0) if residual_out is not None else 0,
|
| 168 |
+
N,
|
| 169 |
+
eps,
|
| 170 |
+
is_rms_norm,
|
| 171 |
+
BLOCK_N,
|
| 172 |
+
residual is not None,
|
| 173 |
+
residual_out is not None,
|
| 174 |
+
bias is not None,
|
| 175 |
+
)
|
| 176 |
+
# residual_out is None if residual is None and residual_dtype == input_dtype
|
| 177 |
+
return y, mean, rstd, residual_out if residual_out is not None else x
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
@triton.autotune(
|
| 181 |
+
configs=[
|
| 182 |
+
triton.Config({}, num_warps=1),
|
| 183 |
+
triton.Config({}, num_warps=2),
|
| 184 |
+
triton.Config({}, num_warps=4),
|
| 185 |
+
triton.Config({}, num_warps=8),
|
| 186 |
+
triton.Config({}, num_warps=16),
|
| 187 |
+
triton.Config({}, num_warps=32),
|
| 188 |
+
],
|
| 189 |
+
key=["N", "HAS_DRESIDUAL", "STORE_DRESIDUAL", "IS_RMS_NORM", "HAS_BIAS"],
|
| 190 |
+
)
|
| 191 |
+
# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
|
| 192 |
+
# @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None})
|
| 193 |
+
# @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None})
|
| 194 |
+
@triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None})
|
| 195 |
+
@triton.jit
|
| 196 |
+
def _layer_norm_bwd_kernel(
|
| 197 |
+
X, # pointer to the input
|
| 198 |
+
W, # pointer to the weights
|
| 199 |
+
B, # pointer to the biases
|
| 200 |
+
Y, # pointer to the output to be recomputed
|
| 201 |
+
DY, # pointer to the output gradient
|
| 202 |
+
DX, # pointer to the input gradient
|
| 203 |
+
DW, # pointer to the partial sum of weights gradient
|
| 204 |
+
DB, # pointer to the partial sum of biases gradient
|
| 205 |
+
DRESIDUAL,
|
| 206 |
+
DRESIDUAL_IN,
|
| 207 |
+
Mean, # pointer to the mean
|
| 208 |
+
Rstd, # pointer to the 1/std
|
| 209 |
+
stride_x_row, # how much to increase the pointer when moving by 1 row
|
| 210 |
+
stride_y_row,
|
| 211 |
+
stride_dy_row,
|
| 212 |
+
stride_dx_row,
|
| 213 |
+
stride_dres_row,
|
| 214 |
+
stride_dres_in_row,
|
| 215 |
+
M, # number of rows in X
|
| 216 |
+
N, # number of columns in X
|
| 217 |
+
eps, # epsilon to avoid division by zero
|
| 218 |
+
rows_per_program,
|
| 219 |
+
IS_RMS_NORM: tl.constexpr,
|
| 220 |
+
BLOCK_N: tl.constexpr,
|
| 221 |
+
HAS_DRESIDUAL: tl.constexpr,
|
| 222 |
+
STORE_DRESIDUAL: tl.constexpr,
|
| 223 |
+
HAS_BIAS: tl.constexpr,
|
| 224 |
+
RECOMPUTE_OUTPUT: tl.constexpr,
|
| 225 |
+
):
|
| 226 |
+
# Map the program id to the elements of X, DX, and DY it should compute.
|
| 227 |
+
row_block_id = tl.program_id(0)
|
| 228 |
+
row_start = row_block_id * rows_per_program
|
| 229 |
+
cols = tl.arange(0, BLOCK_N)
|
| 230 |
+
mask = cols < N
|
| 231 |
+
X += row_start * stride_x_row
|
| 232 |
+
if HAS_DRESIDUAL:
|
| 233 |
+
DRESIDUAL += row_start * stride_dres_row
|
| 234 |
+
if STORE_DRESIDUAL:
|
| 235 |
+
DRESIDUAL_IN += row_start * stride_dres_in_row
|
| 236 |
+
DY += row_start * stride_dy_row
|
| 237 |
+
DX += row_start * stride_dx_row
|
| 238 |
+
if RECOMPUTE_OUTPUT:
|
| 239 |
+
Y += row_start * stride_y_row
|
| 240 |
+
w = tl.load(W + cols, mask=mask).to(tl.float32)
|
| 241 |
+
if RECOMPUTE_OUTPUT and HAS_BIAS:
|
| 242 |
+
b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32)
|
| 243 |
+
dw = tl.zeros((BLOCK_N,), dtype=tl.float32)
|
| 244 |
+
if HAS_BIAS:
|
| 245 |
+
db = tl.zeros((BLOCK_N,), dtype=tl.float32)
|
| 246 |
+
row_end = min((row_block_id + 1) * rows_per_program, M)
|
| 247 |
+
for row in range(row_start, row_end):
|
| 248 |
+
# Load data to SRAM
|
| 249 |
+
x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
|
| 250 |
+
dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
|
| 251 |
+
if not IS_RMS_NORM:
|
| 252 |
+
mean = tl.load(Mean + row)
|
| 253 |
+
rstd = tl.load(Rstd + row)
|
| 254 |
+
# Compute dx
|
| 255 |
+
xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
|
| 256 |
+
xhat = tl.where(mask, xhat, 0.0)
|
| 257 |
+
if RECOMPUTE_OUTPUT:
|
| 258 |
+
y = xhat * w + b if HAS_BIAS else xhat * w
|
| 259 |
+
tl.store(Y + cols, y, mask=mask)
|
| 260 |
+
wdy = w * dy
|
| 261 |
+
dw += dy * xhat
|
| 262 |
+
if HAS_BIAS:
|
| 263 |
+
db += dy
|
| 264 |
+
if not IS_RMS_NORM:
|
| 265 |
+
c1 = tl.sum(xhat * wdy, axis=0) / N
|
| 266 |
+
c2 = tl.sum(wdy, axis=0) / N
|
| 267 |
+
dx = (wdy - (xhat * c1 + c2)) * rstd
|
| 268 |
+
else:
|
| 269 |
+
c1 = tl.sum(xhat * wdy, axis=0) / N
|
| 270 |
+
dx = (wdy - xhat * c1) * rstd
|
| 271 |
+
if HAS_DRESIDUAL:
|
| 272 |
+
dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32)
|
| 273 |
+
dx += dres
|
| 274 |
+
# Write dx
|
| 275 |
+
if STORE_DRESIDUAL:
|
| 276 |
+
tl.store(DRESIDUAL_IN + cols, dx, mask=mask)
|
| 277 |
+
tl.store(DX + cols, dx, mask=mask)
|
| 278 |
+
|
| 279 |
+
X += stride_x_row
|
| 280 |
+
if HAS_DRESIDUAL:
|
| 281 |
+
DRESIDUAL += stride_dres_row
|
| 282 |
+
if STORE_DRESIDUAL:
|
| 283 |
+
DRESIDUAL_IN += stride_dres_in_row
|
| 284 |
+
if RECOMPUTE_OUTPUT:
|
| 285 |
+
Y += stride_y_row
|
| 286 |
+
DY += stride_dy_row
|
| 287 |
+
DX += stride_dx_row
|
| 288 |
+
tl.store(DW + row_block_id * N + cols, dw, mask=mask)
|
| 289 |
+
if HAS_BIAS:
|
| 290 |
+
tl.store(DB + row_block_id * N + cols, db, mask=mask)
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
def _layer_norm_bwd(
|
| 294 |
+
dy,
|
| 295 |
+
x,
|
| 296 |
+
weight,
|
| 297 |
+
bias,
|
| 298 |
+
eps,
|
| 299 |
+
mean,
|
| 300 |
+
rstd,
|
| 301 |
+
dresidual=None,
|
| 302 |
+
has_residual=False,
|
| 303 |
+
is_rms_norm=False,
|
| 304 |
+
x_dtype=None,
|
| 305 |
+
recompute_output=False,
|
| 306 |
+
):
|
| 307 |
+
M, N = x.shape
|
| 308 |
+
assert x.stride(-1) == 1
|
| 309 |
+
assert dy.stride(-1) == 1
|
| 310 |
+
assert dy.shape == (M, N)
|
| 311 |
+
if dresidual is not None:
|
| 312 |
+
assert dresidual.stride(-1) == 1
|
| 313 |
+
assert dresidual.shape == (M, N)
|
| 314 |
+
assert weight.shape == (N,)
|
| 315 |
+
assert weight.stride(-1) == 1
|
| 316 |
+
if bias is not None:
|
| 317 |
+
assert bias.stride(-1) == 1
|
| 318 |
+
assert bias.shape == (N,)
|
| 319 |
+
# allocate output
|
| 320 |
+
dx = (
|
| 321 |
+
torch.empty_like(x)
|
| 322 |
+
if x_dtype is None
|
| 323 |
+
else torch.empty(M, N, dtype=x_dtype, device=x.device)
|
| 324 |
+
)
|
| 325 |
+
dresidual_in = torch.empty_like(x) if has_residual and dx.dtype != x.dtype else None
|
| 326 |
+
y = torch.empty(M, N, dtype=dy.dtype, device=dy.device) if recompute_output else None
|
| 327 |
+
|
| 328 |
+
# Less than 64KB per feature: enqueue fused kernel
|
| 329 |
+
MAX_FUSED_SIZE = 65536 // x.element_size()
|
| 330 |
+
BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
|
| 331 |
+
if N > BLOCK_N:
|
| 332 |
+
raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
|
| 333 |
+
sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
|
| 334 |
+
_dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device)
|
| 335 |
+
_db = (
|
| 336 |
+
torch.empty((sm_count, N), dtype=torch.float32, device=bias.device)
|
| 337 |
+
if bias is not None
|
| 338 |
+
else None
|
| 339 |
+
)
|
| 340 |
+
rows_per_program = math.ceil(M / sm_count)
|
| 341 |
+
grid = (sm_count,)
|
| 342 |
+
with torch.cuda.device(x.device.index):
|
| 343 |
+
_layer_norm_bwd_kernel[grid](
|
| 344 |
+
x,
|
| 345 |
+
weight,
|
| 346 |
+
bias,
|
| 347 |
+
y,
|
| 348 |
+
dy,
|
| 349 |
+
dx,
|
| 350 |
+
_dw,
|
| 351 |
+
_db,
|
| 352 |
+
dresidual,
|
| 353 |
+
dresidual_in,
|
| 354 |
+
mean,
|
| 355 |
+
rstd,
|
| 356 |
+
x.stride(0),
|
| 357 |
+
0 if not recompute_output else y.stride(0),
|
| 358 |
+
dy.stride(0),
|
| 359 |
+
dx.stride(0),
|
| 360 |
+
dresidual.stride(0) if dresidual is not None else 0,
|
| 361 |
+
dresidual_in.stride(0) if dresidual_in is not None else 0,
|
| 362 |
+
M,
|
| 363 |
+
N,
|
| 364 |
+
eps,
|
| 365 |
+
rows_per_program,
|
| 366 |
+
is_rms_norm,
|
| 367 |
+
BLOCK_N,
|
| 368 |
+
dresidual is not None,
|
| 369 |
+
dresidual_in is not None,
|
| 370 |
+
bias is not None,
|
| 371 |
+
)
|
| 372 |
+
dw = _dw.sum(0).to(weight.dtype)
|
| 373 |
+
db = _db.sum(0).to(bias.dtype) if bias is not None else None
|
| 374 |
+
# Don't need to compute dresidual_in separately in this case
|
| 375 |
+
if has_residual and dx.dtype == x.dtype:
|
| 376 |
+
dresidual_in = dx
|
| 377 |
+
return (dx, dw, db, dresidual_in) if not recompute_output else (dx, dw, db, dresidual_in, y)
|
| 378 |
+
|
| 379 |
+
|
| 380 |
+
class LayerNormFn(torch.autograd.Function):
|
| 381 |
+
@staticmethod
|
| 382 |
+
def forward(
|
| 383 |
+
ctx,
|
| 384 |
+
x,
|
| 385 |
+
weight,
|
| 386 |
+
bias,
|
| 387 |
+
residual=None,
|
| 388 |
+
eps=1e-6,
|
| 389 |
+
prenorm=False,
|
| 390 |
+
residual_in_fp32=False,
|
| 391 |
+
is_rms_norm=False,
|
| 392 |
+
):
|
| 393 |
+
x_shape_og = x.shape
|
| 394 |
+
# reshape input data into 2D tensor
|
| 395 |
+
x = x.reshape(-1, x.shape[-1])
|
| 396 |
+
if x.stride(-1) != 1:
|
| 397 |
+
x = x.contiguous()
|
| 398 |
+
if residual is not None:
|
| 399 |
+
assert residual.shape == x_shape_og
|
| 400 |
+
residual = residual.reshape(-1, residual.shape[-1])
|
| 401 |
+
if residual.stride(-1) != 1:
|
| 402 |
+
residual = residual.contiguous()
|
| 403 |
+
weight = weight.contiguous()
|
| 404 |
+
if bias is not None:
|
| 405 |
+
bias = bias.contiguous()
|
| 406 |
+
residual_dtype = (
|
| 407 |
+
residual.dtype
|
| 408 |
+
if residual is not None
|
| 409 |
+
else (torch.float32 if residual_in_fp32 else None)
|
| 410 |
+
)
|
| 411 |
+
y, mean, rstd, residual_out = _layer_norm_fwd(
|
| 412 |
+
x, weight, bias, eps, residual, residual_dtype=residual_dtype, is_rms_norm=is_rms_norm
|
| 413 |
+
)
|
| 414 |
+
ctx.save_for_backward(residual_out, weight, bias, mean, rstd)
|
| 415 |
+
ctx.x_shape_og = x_shape_og
|
| 416 |
+
ctx.eps = eps
|
| 417 |
+
ctx.is_rms_norm = is_rms_norm
|
| 418 |
+
ctx.has_residual = residual is not None
|
| 419 |
+
ctx.prenorm = prenorm
|
| 420 |
+
ctx.x_dtype = x.dtype
|
| 421 |
+
y = y.reshape(x_shape_og)
|
| 422 |
+
return y if not prenorm else (y, residual_out.reshape(x_shape_og))
|
| 423 |
+
|
| 424 |
+
@staticmethod
|
| 425 |
+
def backward(ctx, dy, *args):
|
| 426 |
+
x, weight, bias, mean, rstd = ctx.saved_tensors
|
| 427 |
+
dy = dy.reshape(-1, dy.shape[-1])
|
| 428 |
+
if dy.stride(-1) != 1:
|
| 429 |
+
dy = dy.contiguous()
|
| 430 |
+
assert dy.shape == x.shape
|
| 431 |
+
if ctx.prenorm:
|
| 432 |
+
dresidual = args[0]
|
| 433 |
+
dresidual = dresidual.reshape(-1, dresidual.shape[-1])
|
| 434 |
+
if dresidual.stride(-1) != 1:
|
| 435 |
+
dresidual = dresidual.contiguous()
|
| 436 |
+
assert dresidual.shape == x.shape
|
| 437 |
+
else:
|
| 438 |
+
dresidual = None
|
| 439 |
+
dx, dw, db, dresidual_in = _layer_norm_bwd(
|
| 440 |
+
dy,
|
| 441 |
+
x,
|
| 442 |
+
weight,
|
| 443 |
+
bias,
|
| 444 |
+
ctx.eps,
|
| 445 |
+
mean,
|
| 446 |
+
rstd,
|
| 447 |
+
dresidual,
|
| 448 |
+
ctx.has_residual,
|
| 449 |
+
ctx.is_rms_norm,
|
| 450 |
+
x_dtype=ctx.x_dtype,
|
| 451 |
+
)
|
| 452 |
+
return (
|
| 453 |
+
dx.reshape(ctx.x_shape_og),
|
| 454 |
+
dw,
|
| 455 |
+
db,
|
| 456 |
+
dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
|
| 457 |
+
None,
|
| 458 |
+
None,
|
| 459 |
+
None,
|
| 460 |
+
None,
|
| 461 |
+
)
|
| 462 |
+
|
| 463 |
+
|
| 464 |
+
def layer_norm_fn(
|
| 465 |
+
x,
|
| 466 |
+
weight,
|
| 467 |
+
bias,
|
| 468 |
+
residual=None,
|
| 469 |
+
eps=1e-6,
|
| 470 |
+
prenorm=False,
|
| 471 |
+
residual_in_fp32=False,
|
| 472 |
+
is_rms_norm=False,
|
| 473 |
+
):
|
| 474 |
+
return LayerNormFn.apply(x, weight, bias, residual, eps, prenorm, residual_in_fp32, is_rms_norm)
|
| 475 |
+
|
| 476 |
+
|
| 477 |
+
def rms_norm_fn(x, weight, bias, residual=None, prenorm=False, residual_in_fp32=False, eps=1e-6):
|
| 478 |
+
return LayerNormFn.apply(x, weight, bias, residual, eps, prenorm, residual_in_fp32, True)
|
| 479 |
+
|
| 480 |
+
|
| 481 |
+
class RMSNorm(torch.nn.Module):
|
| 482 |
+
def __init__(self, hidden_size, eps=1e-5, device=None, dtype=None):
|
| 483 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
| 484 |
+
super().__init__()
|
| 485 |
+
self.eps = eps
|
| 486 |
+
self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
|
| 487 |
+
self.register_parameter("bias", None)
|
| 488 |
+
self.reset_parameters()
|
| 489 |
+
|
| 490 |
+
def reset_parameters(self):
|
| 491 |
+
torch.nn.init.ones_(self.weight)
|
| 492 |
+
|
| 493 |
+
def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False):
|
| 494 |
+
return rms_norm_fn(
|
| 495 |
+
x,
|
| 496 |
+
self.weight,
|
| 497 |
+
self.bias,
|
| 498 |
+
residual=residual,
|
| 499 |
+
eps=self.eps,
|
| 500 |
+
prenorm=prenorm,
|
| 501 |
+
residual_in_fp32=residual_in_fp32,
|
| 502 |
+
)
|
| 503 |
+
|
| 504 |
+
|
| 505 |
+
class LayerNormLinearFn(torch.autograd.Function):
|
| 506 |
+
@staticmethod
|
| 507 |
+
@custom_fwd
|
| 508 |
+
def forward(
|
| 509 |
+
ctx,
|
| 510 |
+
x,
|
| 511 |
+
norm_weight,
|
| 512 |
+
norm_bias,
|
| 513 |
+
linear_weight,
|
| 514 |
+
linear_bias,
|
| 515 |
+
residual=None,
|
| 516 |
+
eps=1e-6,
|
| 517 |
+
prenorm=False,
|
| 518 |
+
residual_in_fp32=False,
|
| 519 |
+
is_rms_norm=False,
|
| 520 |
+
):
|
| 521 |
+
x_shape_og = x.shape
|
| 522 |
+
# reshape input data into 2D tensor
|
| 523 |
+
x = x.reshape(-1, x.shape[-1])
|
| 524 |
+
if x.stride(-1) != 1:
|
| 525 |
+
x = x.contiguous()
|
| 526 |
+
if residual is not None:
|
| 527 |
+
assert residual.shape == x_shape_og
|
| 528 |
+
residual = residual.reshape(-1, residual.shape[-1])
|
| 529 |
+
if residual.stride(-1) != 1:
|
| 530 |
+
residual = residual.contiguous()
|
| 531 |
+
norm_weight = norm_weight.contiguous()
|
| 532 |
+
if norm_bias is not None:
|
| 533 |
+
norm_bias = norm_bias.contiguous()
|
| 534 |
+
residual_dtype = (
|
| 535 |
+
residual.dtype
|
| 536 |
+
if residual is not None
|
| 537 |
+
else (torch.float32 if residual_in_fp32 else None)
|
| 538 |
+
)
|
| 539 |
+
y, mean, rstd, residual_out = _layer_norm_fwd(
|
| 540 |
+
x,
|
| 541 |
+
norm_weight,
|
| 542 |
+
norm_bias,
|
| 543 |
+
eps,
|
| 544 |
+
residual,
|
| 545 |
+
out_dtype=None if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype(),
|
| 546 |
+
residual_dtype=residual_dtype,
|
| 547 |
+
is_rms_norm=is_rms_norm,
|
| 548 |
+
)
|
| 549 |
+
y = y.reshape(x_shape_og)
|
| 550 |
+
dtype = torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else y.dtype
|
| 551 |
+
linear_weight = linear_weight.to(dtype)
|
| 552 |
+
linear_bias = linear_bias.to(dtype) if linear_bias is not None else None
|
| 553 |
+
out = F.linear(y.to(linear_weight.dtype), linear_weight, linear_bias)
|
| 554 |
+
# We don't store y, will be recomputed in the backward pass to save memory
|
| 555 |
+
ctx.save_for_backward(residual_out, norm_weight, norm_bias, linear_weight, mean, rstd)
|
| 556 |
+
ctx.x_shape_og = x_shape_og
|
| 557 |
+
ctx.eps = eps
|
| 558 |
+
ctx.is_rms_norm = is_rms_norm
|
| 559 |
+
ctx.has_residual = residual is not None
|
| 560 |
+
ctx.prenorm = prenorm
|
| 561 |
+
ctx.x_dtype = x.dtype
|
| 562 |
+
ctx.linear_bias_is_none = linear_bias is None
|
| 563 |
+
return out if not prenorm else (out, residual_out.reshape(x_shape_og))
|
| 564 |
+
|
| 565 |
+
@staticmethod
|
| 566 |
+
@custom_bwd
|
| 567 |
+
def backward(ctx, dout, *args):
|
| 568 |
+
x, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors
|
| 569 |
+
dout = dout.reshape(-1, dout.shape[-1])
|
| 570 |
+
dy = F.linear(dout, linear_weight.t())
|
| 571 |
+
dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0)
|
| 572 |
+
if dy.stride(-1) != 1:
|
| 573 |
+
dy = dy.contiguous()
|
| 574 |
+
assert dy.shape == x.shape
|
| 575 |
+
if ctx.prenorm:
|
| 576 |
+
dresidual = args[0]
|
| 577 |
+
dresidual = dresidual.reshape(-1, dresidual.shape[-1])
|
| 578 |
+
if dresidual.stride(-1) != 1:
|
| 579 |
+
dresidual = dresidual.contiguous()
|
| 580 |
+
assert dresidual.shape == x.shape
|
| 581 |
+
else:
|
| 582 |
+
dresidual = None
|
| 583 |
+
dx, dnorm_weight, dnorm_bias, dresidual_in, y = _layer_norm_bwd(
|
| 584 |
+
dy,
|
| 585 |
+
x,
|
| 586 |
+
norm_weight,
|
| 587 |
+
norm_bias,
|
| 588 |
+
ctx.eps,
|
| 589 |
+
mean,
|
| 590 |
+
rstd,
|
| 591 |
+
dresidual,
|
| 592 |
+
ctx.has_residual,
|
| 593 |
+
ctx.is_rms_norm,
|
| 594 |
+
x_dtype=ctx.x_dtype,
|
| 595 |
+
recompute_output=True,
|
| 596 |
+
)
|
| 597 |
+
dlinear_weight = torch.einsum("bo,bi->oi", dout, y)
|
| 598 |
+
return (
|
| 599 |
+
dx.reshape(ctx.x_shape_og),
|
| 600 |
+
dnorm_weight,
|
| 601 |
+
dnorm_bias,
|
| 602 |
+
dlinear_weight,
|
| 603 |
+
dlinear_bias,
|
| 604 |
+
dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
|
| 605 |
+
None,
|
| 606 |
+
None,
|
| 607 |
+
None,
|
| 608 |
+
None,
|
| 609 |
+
)
|
| 610 |
+
|
| 611 |
+
|
| 612 |
+
def layer_norm_linear_fn(
|
| 613 |
+
x,
|
| 614 |
+
norm_weight,
|
| 615 |
+
norm_bias,
|
| 616 |
+
linear_weight,
|
| 617 |
+
linear_bias,
|
| 618 |
+
residual=None,
|
| 619 |
+
eps=1e-6,
|
| 620 |
+
prenorm=False,
|
| 621 |
+
residual_in_fp32=False,
|
| 622 |
+
is_rms_norm=False,
|
| 623 |
+
):
|
| 624 |
+
return LayerNormLinearFn.apply(
|
| 625 |
+
x,
|
| 626 |
+
norm_weight,
|
| 627 |
+
norm_bias,
|
| 628 |
+
linear_weight,
|
| 629 |
+
linear_bias,
|
| 630 |
+
residual,
|
| 631 |
+
eps,
|
| 632 |
+
prenorm,
|
| 633 |
+
residual_in_fp32,
|
| 634 |
+
is_rms_norm,
|
| 635 |
+
)
|
mamba_ssm/ops/triton/selective_state_update.py
ADDED
|
@@ -0,0 +1,263 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2024, Tri Dao, Albert Gu.
|
| 2 |
+
|
| 3 |
+
"""We want triton==2.1.0 or triton==2.2.0 or triton==2.3.0 for this
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import math
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
|
| 10 |
+
import triton
|
| 11 |
+
import triton.language as tl
|
| 12 |
+
|
| 13 |
+
from einops import rearrange, repeat
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
@triton.heuristics({"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None})
|
| 17 |
+
@triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None})
|
| 18 |
+
@triton.heuristics({"HAS_Z": lambda args: args["z_ptr"] is not None})
|
| 19 |
+
@triton.heuristics({"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])})
|
| 20 |
+
@triton.jit
|
| 21 |
+
def _selective_scan_update_kernel(
|
| 22 |
+
# Pointers to matrices
|
| 23 |
+
state_ptr, x_ptr, dt_ptr, dt_bias_ptr, A_ptr, B_ptr, C_ptr, D_ptr, z_ptr, out_ptr,
|
| 24 |
+
# Matrix dimensions
|
| 25 |
+
batch, nheads, dim, dstate, nheads_ngroups_ratio,
|
| 26 |
+
# Strides
|
| 27 |
+
stride_state_batch, stride_state_head, stride_state_dim, stride_state_dstate,
|
| 28 |
+
stride_x_batch, stride_x_head, stride_x_dim,
|
| 29 |
+
stride_dt_batch, stride_dt_head, stride_dt_dim,
|
| 30 |
+
stride_dt_bias_head, stride_dt_bias_dim,
|
| 31 |
+
stride_A_head, stride_A_dim, stride_A_dstate,
|
| 32 |
+
stride_B_batch, stride_B_group, stride_B_dstate,
|
| 33 |
+
stride_C_batch, stride_C_group, stride_C_dstate,
|
| 34 |
+
stride_D_head, stride_D_dim,
|
| 35 |
+
stride_z_batch, stride_z_head, stride_z_dim,
|
| 36 |
+
stride_out_batch, stride_out_head, stride_out_dim,
|
| 37 |
+
# Meta-parameters
|
| 38 |
+
DT_SOFTPLUS: tl.constexpr,
|
| 39 |
+
TIE_HDIM: tl.constexpr,
|
| 40 |
+
BLOCK_SIZE_M: tl.constexpr,
|
| 41 |
+
HAS_DT_BIAS: tl.constexpr,
|
| 42 |
+
HAS_D: tl.constexpr,
|
| 43 |
+
HAS_Z: tl.constexpr,
|
| 44 |
+
BLOCK_SIZE_DSTATE: tl.constexpr,
|
| 45 |
+
):
|
| 46 |
+
pid_m = tl.program_id(axis=0)
|
| 47 |
+
pid_b = tl.program_id(axis=1)
|
| 48 |
+
pid_h = tl.program_id(axis=2)
|
| 49 |
+
state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head
|
| 50 |
+
x_ptr += pid_b * stride_x_batch + pid_h * stride_x_head
|
| 51 |
+
dt_ptr += pid_b * stride_dt_batch + pid_h * stride_dt_head
|
| 52 |
+
if HAS_DT_BIAS:
|
| 53 |
+
dt_bias_ptr += pid_h * stride_dt_bias_head
|
| 54 |
+
A_ptr += pid_h * stride_A_head
|
| 55 |
+
B_ptr += pid_b * stride_B_batch + (pid_h // nheads_ngroups_ratio) * stride_B_group
|
| 56 |
+
C_ptr += pid_b * stride_C_batch + (pid_h // nheads_ngroups_ratio) * stride_C_group
|
| 57 |
+
if HAS_Z:
|
| 58 |
+
z_ptr += pid_b * stride_z_batch + pid_h * stride_z_head
|
| 59 |
+
out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head
|
| 60 |
+
|
| 61 |
+
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
| 62 |
+
offs_n = tl.arange(0, BLOCK_SIZE_DSTATE)
|
| 63 |
+
state_ptrs = state_ptr + (offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate)
|
| 64 |
+
x_ptrs = x_ptr + offs_m * stride_x_dim
|
| 65 |
+
dt_ptrs = dt_ptr + offs_m * stride_dt_dim
|
| 66 |
+
if HAS_DT_BIAS:
|
| 67 |
+
dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim
|
| 68 |
+
if HAS_D:
|
| 69 |
+
D_ptr += pid_h * stride_D_head
|
| 70 |
+
A_ptrs = A_ptr + (offs_m[:, None] * stride_A_dim + offs_n[None, :] * stride_A_dstate)
|
| 71 |
+
B_ptrs = B_ptr + offs_n * stride_B_dstate
|
| 72 |
+
C_ptrs = C_ptr + offs_n * stride_C_dstate
|
| 73 |
+
if HAS_D:
|
| 74 |
+
D_ptrs = D_ptr + offs_m * stride_D_dim
|
| 75 |
+
if HAS_Z:
|
| 76 |
+
z_ptrs = z_ptr + offs_m * stride_z_dim
|
| 77 |
+
out_ptrs = out_ptr + offs_m * stride_out_dim
|
| 78 |
+
|
| 79 |
+
state = tl.load(state_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0)
|
| 80 |
+
x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
|
| 81 |
+
if not TIE_HDIM:
|
| 82 |
+
dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
|
| 83 |
+
if HAS_DT_BIAS:
|
| 84 |
+
dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
|
| 85 |
+
if DT_SOFTPLUS:
|
| 86 |
+
dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt)
|
| 87 |
+
A = tl.load(A_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32)
|
| 88 |
+
dA = tl.exp(A * dt[:, None])
|
| 89 |
+
else:
|
| 90 |
+
dt = tl.load(dt_ptr).to(tl.float32)
|
| 91 |
+
if HAS_DT_BIAS:
|
| 92 |
+
dt += tl.load(dt_bias_ptr).to(tl.float32)
|
| 93 |
+
if DT_SOFTPLUS:
|
| 94 |
+
dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt)
|
| 95 |
+
A = tl.load(A_ptr).to(tl.float32)
|
| 96 |
+
dA = tl.exp(A * dt) # scalar, not a matrix
|
| 97 |
+
|
| 98 |
+
B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
|
| 99 |
+
C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
|
| 100 |
+
if HAS_D:
|
| 101 |
+
D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
|
| 102 |
+
if HAS_Z:
|
| 103 |
+
z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
|
| 104 |
+
|
| 105 |
+
if not TIE_HDIM:
|
| 106 |
+
dB = B[None, :] * dt[:, None]
|
| 107 |
+
else:
|
| 108 |
+
dB = B * dt # vector of size (dstate,)
|
| 109 |
+
state = state * dA + dB * x[:, None]
|
| 110 |
+
tl.store(state_ptrs, state, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate))
|
| 111 |
+
out = tl.sum(state * C[None, :], axis=1)
|
| 112 |
+
if HAS_D:
|
| 113 |
+
out += x * D
|
| 114 |
+
if HAS_Z:
|
| 115 |
+
out *= z * tl.sigmoid(z)
|
| 116 |
+
tl.store(out_ptrs, out, mask=offs_m < dim)
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def selective_state_update(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False):
|
| 120 |
+
"""
|
| 121 |
+
Argument:
|
| 122 |
+
state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
|
| 123 |
+
x: (batch, dim) or (batch, nheads, dim)
|
| 124 |
+
dt: (batch, dim) or (batch, nheads, dim)
|
| 125 |
+
A: (dim, dstate) or (nheads, dim, dstate)
|
| 126 |
+
B: (batch, dstate) or (batch, ngroups, dstate)
|
| 127 |
+
C: (batch, dstate) or (batch, ngroups, dstate)
|
| 128 |
+
D: (dim,) or (nheads, dim)
|
| 129 |
+
z: (batch, dim) or (batch, nheads, dim)
|
| 130 |
+
dt_bias: (dim,) or (nheads, dim)
|
| 131 |
+
Return:
|
| 132 |
+
out: (batch, dim) or (batch, nheads, dim)
|
| 133 |
+
"""
|
| 134 |
+
has_heads = state.dim() > 3
|
| 135 |
+
if state.dim() == 3:
|
| 136 |
+
state = state.unsqueeze(1)
|
| 137 |
+
if x.dim() == 2:
|
| 138 |
+
x = x.unsqueeze(1)
|
| 139 |
+
if dt.dim() == 2:
|
| 140 |
+
dt = dt.unsqueeze(1)
|
| 141 |
+
if A.dim() == 2:
|
| 142 |
+
A = A.unsqueeze(0)
|
| 143 |
+
if B.dim() == 2:
|
| 144 |
+
B = B.unsqueeze(1)
|
| 145 |
+
if C.dim() == 2:
|
| 146 |
+
C = C.unsqueeze(1)
|
| 147 |
+
if D is not None and D.dim() == 1:
|
| 148 |
+
D = D.unsqueeze(0)
|
| 149 |
+
if z is not None and z.dim() == 2:
|
| 150 |
+
z = z.unsqueeze(1)
|
| 151 |
+
if dt_bias is not None and dt_bias.dim() == 1:
|
| 152 |
+
dt_bias = dt_bias.unsqueeze(0)
|
| 153 |
+
batch, nheads, dim, dstate = state.shape
|
| 154 |
+
assert x.shape == (batch, nheads, dim)
|
| 155 |
+
assert dt.shape == x.shape
|
| 156 |
+
assert A.shape == (nheads, dim, dstate)
|
| 157 |
+
ngroups = B.shape[1]
|
| 158 |
+
assert nheads % ngroups == 0, "nheads must be divisible by ngroups"
|
| 159 |
+
assert B.shape == (batch, ngroups, dstate)
|
| 160 |
+
assert C.shape == B.shape
|
| 161 |
+
if D is not None:
|
| 162 |
+
assert D.shape == (nheads, dim)
|
| 163 |
+
if z is not None:
|
| 164 |
+
assert z.shape == x.shape
|
| 165 |
+
if dt_bias is not None:
|
| 166 |
+
assert dt_bias.shape == (nheads, dim)
|
| 167 |
+
out = torch.empty_like(x)
|
| 168 |
+
grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch, nheads)
|
| 169 |
+
z_strides = ((z.stride(0), z.stride(1), z.stride(2)) if z is not None else (0, 0, 0))
|
| 170 |
+
# We don't want autotune since it will overwrite the state
|
| 171 |
+
# We instead tune by hand.
|
| 172 |
+
BLOCK_SIZE_M, num_warps = ((32, 4) if dstate <= 16
|
| 173 |
+
else ((16, 4) if dstate <= 32 else
|
| 174 |
+
((8, 4) if dstate <= 64 else
|
| 175 |
+
((4, 4) if dstate <= 128 else
|
| 176 |
+
((4, 8))))))
|
| 177 |
+
tie_hdim = A.stride(-1) == 0 and A.stride(-2) == 0 and dt.stride(-1) == 0 and dt_bias.stride(-1) == 0
|
| 178 |
+
with torch.cuda.device(x.device.index):
|
| 179 |
+
_selective_scan_update_kernel[grid](
|
| 180 |
+
state, x, dt, dt_bias, A, B, C, D, z, out,
|
| 181 |
+
batch, nheads, dim, dstate, nheads // ngroups,
|
| 182 |
+
state.stride(0), state.stride(1), state.stride(2), state.stride(3),
|
| 183 |
+
x.stride(0), x.stride(1), x.stride(2),
|
| 184 |
+
dt.stride(0), dt.stride(1), dt.stride(2),
|
| 185 |
+
*(dt_bias.stride(0), dt_bias.stride(1)) if dt_bias is not None else 0,
|
| 186 |
+
A.stride(0), A.stride(1), A.stride(2),
|
| 187 |
+
B.stride(0), B.stride(1), B.stride(2),
|
| 188 |
+
C.stride(0), C.stride(1), C.stride(2),
|
| 189 |
+
*(D.stride(0), D.stride(1)) if D is not None else 0,
|
| 190 |
+
z_strides[0], z_strides[1], z_strides[2],
|
| 191 |
+
out.stride(0), out.stride(1), out.stride(2),
|
| 192 |
+
dt_softplus,
|
| 193 |
+
tie_hdim,
|
| 194 |
+
BLOCK_SIZE_M,
|
| 195 |
+
num_warps=num_warps,
|
| 196 |
+
)
|
| 197 |
+
if not has_heads:
|
| 198 |
+
out = out.squeeze(1)
|
| 199 |
+
return out
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
def selective_state_update_ref(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False):
|
| 203 |
+
"""
|
| 204 |
+
Argument:
|
| 205 |
+
state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
|
| 206 |
+
x: (batch, dim) or (batch, nheads, dim)
|
| 207 |
+
dt: (batch, dim) or (batch, nheads, dim)
|
| 208 |
+
A: (dim, dstate) or (nheads, dim, dstate)
|
| 209 |
+
B: (batch, dstate) or (batch, ngroups, dstate)
|
| 210 |
+
C: (batch, dstate) or (batch, ngroups, dstate)
|
| 211 |
+
D: (dim,) or (nheads, dim)
|
| 212 |
+
z: (batch, dim) or (batch, nheads, dim)
|
| 213 |
+
dt_bias: (dim,) or (nheads, dim)
|
| 214 |
+
Return:
|
| 215 |
+
out: (batch, dim) or (batch, nheads, dim)
|
| 216 |
+
"""
|
| 217 |
+
has_heads = state.dim() > 3
|
| 218 |
+
if state.dim() == 3:
|
| 219 |
+
state = state.unsqueeze(1)
|
| 220 |
+
if x.dim() == 2:
|
| 221 |
+
x = x.unsqueeze(1)
|
| 222 |
+
if dt.dim() == 2:
|
| 223 |
+
dt = dt.unsqueeze(1)
|
| 224 |
+
if A.dim() == 2:
|
| 225 |
+
A = A.unsqueeze(0)
|
| 226 |
+
if B.dim() == 2:
|
| 227 |
+
B = B.unsqueeze(1)
|
| 228 |
+
if C.dim() == 2:
|
| 229 |
+
C = C.unsqueeze(1)
|
| 230 |
+
if D is not None and D.dim() == 1:
|
| 231 |
+
D = D.unsqueeze(0)
|
| 232 |
+
if z is not None and z.dim() == 2:
|
| 233 |
+
z = z.unsqueeze(1)
|
| 234 |
+
if dt_bias is not None and dt_bias.dim() == 1:
|
| 235 |
+
dt_bias = dt_bias.unsqueeze(0)
|
| 236 |
+
batch, nheads, dim, dstate = state.shape
|
| 237 |
+
assert x.shape == (batch, nheads, dim)
|
| 238 |
+
assert dt.shape == x.shape
|
| 239 |
+
assert A.shape == (nheads, dim, dstate)
|
| 240 |
+
ngroups = B.shape[1]
|
| 241 |
+
assert nheads % ngroups == 0, "nheads must be divisible by ngroups"
|
| 242 |
+
assert B.shape == (batch, ngroups, dstate)
|
| 243 |
+
assert C.shape == B.shape
|
| 244 |
+
if D is not None:
|
| 245 |
+
assert D.shape == (nheads, dim)
|
| 246 |
+
if z is not None:
|
| 247 |
+
assert z.shape == x.shape
|
| 248 |
+
if dt_bias is not None:
|
| 249 |
+
assert dt_bias.shape == (nheads, dim)
|
| 250 |
+
dt = dt + dt_bias
|
| 251 |
+
dt = F.softplus(dt) if dt_softplus else dt
|
| 252 |
+
dA = torch.exp(rearrange(dt, "b h d -> b h d 1") * A) # (batch, nheads, dim, dstate)
|
| 253 |
+
B = repeat(B, "b g n -> b (g h) n", h=nheads // ngroups) # (batch, nheads, dstate)
|
| 254 |
+
C = repeat(C, "b g n -> b (g h) n", h=nheads // ngroups) # (batch, nheads, dstate)
|
| 255 |
+
dB = rearrange(dt, "b h d -> b h d 1") * rearrange(B, "b h n -> b h 1 n") # (batch, nheads, dim, dstate)
|
| 256 |
+
state.copy_(state * dA + dB * rearrange(x, "b h d -> b h d 1")) # (batch, dim, dstate
|
| 257 |
+
out = torch.einsum("bhdn,bhn->bhd", state.to(C.dtype), C)
|
| 258 |
+
if D is not None:
|
| 259 |
+
out += (x * D).to(out.dtype)
|
| 260 |
+
out = (out if z is None else out * F.silu(z)).to(x.dtype)
|
| 261 |
+
if not has_heads:
|
| 262 |
+
out = out.squeeze(1)
|
| 263 |
+
return out
|
mamba_ssm/utils/__init__.py
ADDED
|
File without changes
|
mamba_ssm/utils/generation.py
ADDED
|
@@ -0,0 +1,387 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2023, Albert Gu, Tri Dao.
|
| 2 |
+
import gc
|
| 3 |
+
import time
|
| 4 |
+
from collections import namedtuple
|
| 5 |
+
from dataclasses import dataclass, field
|
| 6 |
+
from functools import partial
|
| 7 |
+
from typing import Callable, Optional, Sequence, Union
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
from einops import rearrange, repeat
|
| 12 |
+
from torch import Tensor
|
| 13 |
+
from torch.profiler import ProfilerActivity, profile, record_function
|
| 14 |
+
from transformers.generation import GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput, TextStreamer
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
@dataclass
|
| 18 |
+
class InferenceParams:
|
| 19 |
+
"""Inference parameters that are passed to the main model in order
|
| 20 |
+
to efficienly calculate and store the context during inference."""
|
| 21 |
+
|
| 22 |
+
max_seqlen: int
|
| 23 |
+
max_batch_size: int
|
| 24 |
+
seqlen_offset: int = 0
|
| 25 |
+
batch_size_offset: int = 0
|
| 26 |
+
key_value_memory_dict: dict = field(default_factory=dict)
|
| 27 |
+
lengths_per_sample: Optional[Tensor] = None
|
| 28 |
+
|
| 29 |
+
def reset(self, max_seqlen, max_batch_size):
|
| 30 |
+
self.max_seqlen = max_seqlen
|
| 31 |
+
self.max_batch_size = max_batch_size
|
| 32 |
+
self.seqlen_offset = 0
|
| 33 |
+
if self.lengths_per_sample is not None:
|
| 34 |
+
self.lengths_per_sample.zero_()
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def modify_logits_for_min_p_filtering(logits, min_p):
|
| 38 |
+
"""Set the logits for none min_p values to -inf. Done in-place."""
|
| 39 |
+
if min_p <= 0.0 or min_p >= 1.0:
|
| 40 |
+
return
|
| 41 |
+
indices_to_remove = logits < min_p
|
| 42 |
+
logits.masked_fill_(indices_to_remove, float("-Inf"))
|
| 43 |
+
# https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
|
| 44 |
+
# https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L231
|
| 45 |
+
def modify_logits_for_top_k_filtering(logits, top_k):
|
| 46 |
+
"""Set the logits for none top-k values to -inf. Done in-place."""
|
| 47 |
+
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
|
| 48 |
+
logits.masked_fill_(indices_to_remove, float("-Inf"))
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
# https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
|
| 52 |
+
# https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L170
|
| 53 |
+
def modify_logits_for_top_p_filtering(logits, top_p):
|
| 54 |
+
"""Set the logits for none top-p values to -inf. Done in-place."""
|
| 55 |
+
if top_p <= 0.0 or top_p >= 1.0:
|
| 56 |
+
return
|
| 57 |
+
# First sort and calculate cumulative sum of probabilities.
|
| 58 |
+
sorted_logits, sorted_indices = torch.sort(logits, descending=False)
|
| 59 |
+
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
|
| 60 |
+
# Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
|
| 61 |
+
sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
|
| 62 |
+
# scatter sorted tensors to original indexing
|
| 63 |
+
indices_to_remove = sorted_indices_to_remove.scatter(
|
| 64 |
+
1, sorted_indices, sorted_indices_to_remove
|
| 65 |
+
)
|
| 66 |
+
logits.masked_fill_(indices_to_remove, float("-inf"))
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def modify_logit_for_repetition_penalty(logits, prev_output_tokens, repetition_penalty=1.0):
|
| 70 |
+
"""Apply repetition penalty. See https://arxiv.org/abs/1909.05858
|
| 71 |
+
logits: (batch_size, vocab_size)
|
| 72 |
+
prev_output_tokens: (batch_size, seq_len)
|
| 73 |
+
"""
|
| 74 |
+
if repetition_penalty == 1.0:
|
| 75 |
+
return logits
|
| 76 |
+
score = torch.gather(logits, 1, prev_output_tokens)
|
| 77 |
+
# if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
|
| 78 |
+
score = torch.where(score < 0, score * repetition_penalty, score / repetition_penalty)
|
| 79 |
+
logits.scatter_(1, prev_output_tokens, score)
|
| 80 |
+
return logits
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def sample(logits, top_k=1, top_p=0.0, min_p=0.0, temperature=1.0):
|
| 84 |
+
"""Sample from top-k logits.
|
| 85 |
+
Arguments:
|
| 86 |
+
logits: Tensor of shape (batch_size, vocab_size)
|
| 87 |
+
"""
|
| 88 |
+
if top_k == 1: # Short-circuit for greedy decoding
|
| 89 |
+
return logits.argmax(dim=-1)
|
| 90 |
+
else:
|
| 91 |
+
if top_p > 0.0:
|
| 92 |
+
assert top_p <= 1.0, "top-p should be in (0, 1]."
|
| 93 |
+
if top_k > 0:
|
| 94 |
+
top_k = min(top_k, logits.size(-1)) # Safety check
|
| 95 |
+
logits_top, indices = torch.topk(logits, top_k, dim=-1)
|
| 96 |
+
if temperature != 1.0:
|
| 97 |
+
logits_top /= temperature
|
| 98 |
+
modify_logits_for_top_p_filtering(logits_top, top_p)
|
| 99 |
+
return indices[
|
| 100 |
+
torch.arange(indices.shape[0], device=indices.device),
|
| 101 |
+
torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1),
|
| 102 |
+
]
|
| 103 |
+
else:
|
| 104 |
+
if min_p > 0.0:
|
| 105 |
+
logits_top = logits.clone()
|
| 106 |
+
max_prob = logits_top[..., 0].item()
|
| 107 |
+
min_prob = max_prob * min_p
|
| 108 |
+
modify_logits_for_min_p_filtering(logits_top, min_p)
|
| 109 |
+
if temperature != 1.0:
|
| 110 |
+
logits_top /= temperature
|
| 111 |
+
return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1)
|
| 112 |
+
# Clone so that when we modify for top_p we don't change the original logits
|
| 113 |
+
logits_top = logits / temperature if temperature != 1.0 else logits.clone()
|
| 114 |
+
modify_logits_for_top_p_filtering(logits_top, top_p)
|
| 115 |
+
return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(
|
| 116 |
+
dim=-1
|
| 117 |
+
)
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
@torch.inference_mode()
|
| 121 |
+
def decode(
|
| 122 |
+
input_ids,
|
| 123 |
+
model,
|
| 124 |
+
max_length,
|
| 125 |
+
top_k=1,
|
| 126 |
+
top_p=0.0,
|
| 127 |
+
min_p=0.0,
|
| 128 |
+
temperature=1.0,
|
| 129 |
+
repetition_penalty=1.0,
|
| 130 |
+
eos_token_id=None,
|
| 131 |
+
teacher_outputs=None,
|
| 132 |
+
vocab_size=None,
|
| 133 |
+
cg=False,
|
| 134 |
+
enable_timing=False,
|
| 135 |
+
streamer: Optional[TextStreamer] = None
|
| 136 |
+
):
|
| 137 |
+
"""Decoding, either greedy or with top-k or top-p sampling.
|
| 138 |
+
If top-k = 0, don't limit the number of candidates (pure sampling).
|
| 139 |
+
Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first,
|
| 140 |
+
then top-p.
|
| 141 |
+
We assume that all sequences in the same batch have the same length.
|
| 142 |
+
|
| 143 |
+
Arguments:
|
| 144 |
+
input_ids: (batch, seq_len)
|
| 145 |
+
max_length: int
|
| 146 |
+
teacher_outputs (optional): (batch, seq_len). If provided, instead of sampling from the
|
| 147 |
+
logits, the next token is taken from the teacher_outputs. Useful for testing.
|
| 148 |
+
Returns: GreedySearchDecoderOnlyOutput or SampleDecoderOnlyOutput, with the following fields:
|
| 149 |
+
sequences: (batch, max_length)
|
| 150 |
+
scores: tuples of (batch, vocab_size)
|
| 151 |
+
"""
|
| 152 |
+
if streamer is not None:
|
| 153 |
+
streamer.put(input_ids.cpu())
|
| 154 |
+
|
| 155 |
+
batch_size, seqlen_og = input_ids.shape
|
| 156 |
+
teacher_output_len = teacher_outputs.shape[1] if teacher_outputs is not None else 0
|
| 157 |
+
if cg:
|
| 158 |
+
if not hasattr(model, "_decoding_cache"):
|
| 159 |
+
model._decoding_cache = None
|
| 160 |
+
model._decoding_cache = update_graph_cache(
|
| 161 |
+
model,
|
| 162 |
+
model._decoding_cache,
|
| 163 |
+
batch_size,
|
| 164 |
+
seqlen_og,
|
| 165 |
+
max_length,
|
| 166 |
+
)
|
| 167 |
+
inference_params = model._decoding_cache.inference_params
|
| 168 |
+
inference_params.reset(max_length, batch_size)
|
| 169 |
+
else:
|
| 170 |
+
inference_params = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size)
|
| 171 |
+
|
| 172 |
+
def get_logits(input_ids, inference_params):
|
| 173 |
+
decoding = inference_params.seqlen_offset > 0
|
| 174 |
+
if decoding:
|
| 175 |
+
position_ids = torch.full(
|
| 176 |
+
(batch_size, 1),
|
| 177 |
+
inference_params.seqlen_offset,
|
| 178 |
+
dtype=torch.long,
|
| 179 |
+
device=input_ids.device,
|
| 180 |
+
)
|
| 181 |
+
else:
|
| 182 |
+
position_ids = None
|
| 183 |
+
if not cg or not decoding:
|
| 184 |
+
logits = model(
|
| 185 |
+
input_ids,
|
| 186 |
+
position_ids=position_ids,
|
| 187 |
+
inference_params=inference_params,
|
| 188 |
+
num_last_tokens=1,
|
| 189 |
+
).logits.squeeze(dim=1)
|
| 190 |
+
else:
|
| 191 |
+
logits = model._decoding_cache.run(
|
| 192 |
+
input_ids, position_ids, inference_params.seqlen_offset
|
| 193 |
+
).squeeze(dim=1)
|
| 194 |
+
return logits[..., :vocab_size] if vocab_size is not None else logits
|
| 195 |
+
|
| 196 |
+
def sample_tokens(logits, inference_params):
|
| 197 |
+
if teacher_outputs is None or teacher_output_len <= inference_params.seqlen_offset:
|
| 198 |
+
token = sample(logits, top_k=top_k, top_p=top_p, min_p=min_p, temperature=temperature)
|
| 199 |
+
else:
|
| 200 |
+
token = teacher_outputs[:, inference_params.seqlen_offset]
|
| 201 |
+
# return rearrange(token, "b -> b 1")
|
| 202 |
+
return token.unsqueeze(1)
|
| 203 |
+
|
| 204 |
+
def should_stop(current_token, inference_params):
|
| 205 |
+
if inference_params.seqlen_offset == 0:
|
| 206 |
+
return False
|
| 207 |
+
if eos_token_id is not None and (current_token == eos_token_id).all():
|
| 208 |
+
return True
|
| 209 |
+
if inference_params.seqlen_offset >= max_length - 1:
|
| 210 |
+
return True
|
| 211 |
+
return False
|
| 212 |
+
|
| 213 |
+
start = torch.cuda.Event(enable_timing=enable_timing)
|
| 214 |
+
end = torch.cuda.Event(enable_timing=enable_timing)
|
| 215 |
+
|
| 216 |
+
if enable_timing:
|
| 217 |
+
start.record()
|
| 218 |
+
scores, sequences = [], [input_ids]
|
| 219 |
+
sequences_cat = input_ids
|
| 220 |
+
while not should_stop(sequences[-1], inference_params):
|
| 221 |
+
scores.append(get_logits(sequences[-1], inference_params))
|
| 222 |
+
inference_params.seqlen_offset += sequences[-1].shape[1]
|
| 223 |
+
if repetition_penalty == 1.0:
|
| 224 |
+
sampled_tokens = sample_tokens(scores[-1], inference_params)
|
| 225 |
+
else:
|
| 226 |
+
logits = modify_logit_for_repetition_penalty(
|
| 227 |
+
scores[-1].clone(), sequences_cat, repetition_penalty
|
| 228 |
+
)
|
| 229 |
+
sampled_tokens = sample_tokens(logits, inference_params)
|
| 230 |
+
sequences_cat = torch.cat([sequences_cat, sampled_tokens], dim=1)
|
| 231 |
+
sequences.append(sampled_tokens)
|
| 232 |
+
if streamer is not None:
|
| 233 |
+
streamer.put(sampled_tokens.cpu())
|
| 234 |
+
if streamer is not None:
|
| 235 |
+
streamer.end()
|
| 236 |
+
if enable_timing:
|
| 237 |
+
end.record()
|
| 238 |
+
torch.cuda.synchronize()
|
| 239 |
+
print(f"Prompt processing + decoding time: {(start.elapsed_time(end)):.0f}ms")
|
| 240 |
+
output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
|
| 241 |
+
return output_cls(sequences=torch.cat(sequences, dim=1), scores=tuple(scores))
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
class GenerationMixin:
|
| 245 |
+
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
|
| 246 |
+
raise NotImplementedError
|
| 247 |
+
|
| 248 |
+
def generate(
|
| 249 |
+
self,
|
| 250 |
+
input_ids,
|
| 251 |
+
max_length,
|
| 252 |
+
top_k=1,
|
| 253 |
+
top_p=0.0,
|
| 254 |
+
min_p=0.0,
|
| 255 |
+
temperature=1.0,
|
| 256 |
+
return_dict_in_generate=False,
|
| 257 |
+
output_scores=False,
|
| 258 |
+
**kwargs,
|
| 259 |
+
):
|
| 260 |
+
output = decode(
|
| 261 |
+
input_ids, self, max_length, top_k=top_k, top_p=top_p, min_p = min_p, temperature=temperature, **kwargs
|
| 262 |
+
)
|
| 263 |
+
if not output_scores:
|
| 264 |
+
output.scores = None
|
| 265 |
+
return output if return_dict_in_generate else output.sequences
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
@dataclass
|
| 269 |
+
class DecodingCGCache:
|
| 270 |
+
max_batch_size: int = 0
|
| 271 |
+
max_seqlen: int = 0
|
| 272 |
+
device = None
|
| 273 |
+
dtype = None
|
| 274 |
+
callables: dict = field(default_factory=dict)
|
| 275 |
+
mempool = None
|
| 276 |
+
inference_params: Optional[InferenceParams] = None
|
| 277 |
+
run: Optional[Callable] = None
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
@torch.inference_mode()
|
| 281 |
+
def update_graph_cache(
|
| 282 |
+
model,
|
| 283 |
+
cache,
|
| 284 |
+
batch_size,
|
| 285 |
+
seqlen_og,
|
| 286 |
+
max_seqlen,
|
| 287 |
+
decoding_seqlens=(1,),
|
| 288 |
+
dtype=None,
|
| 289 |
+
n_warmups=2,
|
| 290 |
+
):
|
| 291 |
+
if cache is None:
|
| 292 |
+
cache = DecodingCGCache()
|
| 293 |
+
param_example = next(iter(model.parameters()))
|
| 294 |
+
device = param_example.device
|
| 295 |
+
if dtype is None:
|
| 296 |
+
dtype = param_example.dtype
|
| 297 |
+
if (
|
| 298 |
+
(device, dtype) != (cache.device, cache.dtype)
|
| 299 |
+
or batch_size > cache.max_batch_size
|
| 300 |
+
or max_seqlen > cache.max_seqlen
|
| 301 |
+
): # Invalidate the cache
|
| 302 |
+
cache.callables = {}
|
| 303 |
+
cache.mempool = None
|
| 304 |
+
cache.inference_params = None
|
| 305 |
+
gc.collect()
|
| 306 |
+
cache.device, cache.dtype = device, dtype
|
| 307 |
+
cache.max_batch_size, cache.max_seqlen = batch_size, max_seqlen
|
| 308 |
+
assert hasattr(model, "allocate_inference_cache"), "CUDA graph decoding requires that the model has a method allocate_inference_cache"
|
| 309 |
+
inf_cache = model.allocate_inference_cache(batch_size, max_seqlen, dtype)
|
| 310 |
+
lengths_per_sample = torch.full((batch_size,), seqlen_og, dtype=torch.int32, device=device)
|
| 311 |
+
cache.inference_params = InferenceParams(
|
| 312 |
+
max_seqlen=max_seqlen,
|
| 313 |
+
max_batch_size=batch_size,
|
| 314 |
+
seqlen_offset=seqlen_og,
|
| 315 |
+
key_value_memory_dict=inf_cache,
|
| 316 |
+
lengths_per_sample=lengths_per_sample,
|
| 317 |
+
)
|
| 318 |
+
cache.mempool = torch.cuda.graphs.graph_pool_handle()
|
| 319 |
+
for decoding_seqlen in decoding_seqlens:
|
| 320 |
+
if (batch_size, decoding_seqlen) not in cache.callables:
|
| 321 |
+
cache.callables[batch_size, decoding_seqlen] = capture_graph(
|
| 322 |
+
model,
|
| 323 |
+
cache.inference_params,
|
| 324 |
+
batch_size,
|
| 325 |
+
max_seqlen,
|
| 326 |
+
decoding_seqlen=decoding_seqlen,
|
| 327 |
+
mempool=cache.mempool,
|
| 328 |
+
n_warmups=n_warmups,
|
| 329 |
+
)
|
| 330 |
+
|
| 331 |
+
def dispatch(input_ids, position_ids, seqlen):
|
| 332 |
+
batch_size, decoding_seqlen = input_ids.shape[:2]
|
| 333 |
+
return cache.callables[batch_size, decoding_seqlen](input_ids, position_ids, seqlen)
|
| 334 |
+
|
| 335 |
+
cache.run = dispatch
|
| 336 |
+
cache.inference_params.seqlen_offset = 0 # Reset so it's not confusing
|
| 337 |
+
return cache
|
| 338 |
+
|
| 339 |
+
|
| 340 |
+
def capture_graph(
|
| 341 |
+
model, inference_params, batch_size, max_seqlen, decoding_seqlen=1, mempool=None, n_warmups=2
|
| 342 |
+
):
|
| 343 |
+
device = next(iter(model.parameters())).device
|
| 344 |
+
input_ids = torch.full((batch_size, decoding_seqlen), 0, dtype=torch.long, device=device)
|
| 345 |
+
position_ids = torch.full((batch_size, decoding_seqlen), 0, dtype=torch.long, device=device)
|
| 346 |
+
seqlen_offset_og = inference_params.seqlen_offset
|
| 347 |
+
inference_params.seqlen_offset = max_seqlen - decoding_seqlen
|
| 348 |
+
inference_params.lengths_per_sample[:] = inference_params.seqlen_offset
|
| 349 |
+
|
| 350 |
+
# Warmup before capture
|
| 351 |
+
s = torch.cuda.Stream()
|
| 352 |
+
s.wait_stream(torch.cuda.current_stream())
|
| 353 |
+
with torch.cuda.stream(s):
|
| 354 |
+
for _ in range(n_warmups):
|
| 355 |
+
logits = model(
|
| 356 |
+
input_ids,
|
| 357 |
+
position_ids=position_ids,
|
| 358 |
+
inference_params=inference_params,
|
| 359 |
+
num_last_tokens=decoding_seqlen,
|
| 360 |
+
).logits
|
| 361 |
+
s.synchronize()
|
| 362 |
+
# This might be needed for correctness if we run with NCCL_GRAPH_MIXING_SUPPORT=0,
|
| 363 |
+
# which requires that graph launch and non-captured launch to not overlap (I think,
|
| 364 |
+
# that's how I interpret the documentation). I'm not sure if this is required.
|
| 365 |
+
if torch.distributed.is_initialized():
|
| 366 |
+
torch.distributed.barrier()
|
| 367 |
+
torch.cuda.current_stream().wait_stream(s)
|
| 368 |
+
# Captures the graph
|
| 369 |
+
# To allow capture, automatically sets a side stream as the current stream in the context
|
| 370 |
+
graph = torch.cuda.CUDAGraph()
|
| 371 |
+
with torch.cuda.graph(graph, pool=mempool):
|
| 372 |
+
logits = model(
|
| 373 |
+
input_ids,
|
| 374 |
+
position_ids=position_ids,
|
| 375 |
+
inference_params=inference_params,
|
| 376 |
+
num_last_tokens=decoding_seqlen,
|
| 377 |
+
).logits
|
| 378 |
+
|
| 379 |
+
def run(new_input_ids, new_position_ids, seqlen):
|
| 380 |
+
inference_params.lengths_per_sample[:] = seqlen
|
| 381 |
+
input_ids.copy_(new_input_ids)
|
| 382 |
+
position_ids.copy_(new_position_ids)
|
| 383 |
+
graph.replay()
|
| 384 |
+
return logits.clone()
|
| 385 |
+
|
| 386 |
+
inference_params.seqlen_offset = seqlen_offset_og
|
| 387 |
+
return run
|
mamba_ssm/utils/hf.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
from transformers.utils import WEIGHTS_NAME, CONFIG_NAME
|
| 6 |
+
from transformers.utils.hub import cached_file
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def load_config_hf(model_name):
|
| 10 |
+
resolved_archive_file = cached_file(model_name, CONFIG_NAME, _raise_exceptions_for_missing_entries=False)
|
| 11 |
+
return json.load(open(resolved_archive_file))
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def load_state_dict_hf(model_name, device=None, dtype=None):
|
| 15 |
+
# If not fp32, then we don't want to load directly to the GPU
|
| 16 |
+
mapped_device = "cpu" if dtype not in [torch.float32, None] else device
|
| 17 |
+
resolved_archive_file = cached_file(model_name, WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False)
|
| 18 |
+
return torch.load(resolved_archive_file, map_location=mapped_device)
|
| 19 |
+
# Convert dtype before moving to GPU to save memory
|
| 20 |
+
if dtype is not None:
|
| 21 |
+
state_dict = {k: v.to(dtype=dtype) for k, v in state_dict.items()}
|
| 22 |
+
state_dict = {k: v.to(device=device) for k, v in state_dict.items()}
|
| 23 |
+
return state_dict
|
models/codec_module.py
ADDED
|
@@ -0,0 +1,183 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Reference: https://github.com/yxlu-0102/MP-SENet/blob/main/models/generator.py
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
from einops import rearrange
|
| 6 |
+
from .lsigmoid import LearnableSigmoid2D
|
| 7 |
+
|
| 8 |
+
def get_padding(kernel_size, dilation=1):
|
| 9 |
+
"""
|
| 10 |
+
Calculate the padding size for a convolutional layer.
|
| 11 |
+
|
| 12 |
+
Args:
|
| 13 |
+
- kernel_size (int): Size of the convolutional kernel.
|
| 14 |
+
- dilation (int, optional): Dilation rate of the convolution. Defaults to 1.
|
| 15 |
+
|
| 16 |
+
Returns:
|
| 17 |
+
- int: Calculated padding size.
|
| 18 |
+
"""
|
| 19 |
+
return int((kernel_size * dilation - dilation) / 2)
|
| 20 |
+
|
| 21 |
+
def get_padding_2d(kernel_size, dilation=(1, 1)):
|
| 22 |
+
"""
|
| 23 |
+
Calculate the padding size for a 2D convolutional layer.
|
| 24 |
+
|
| 25 |
+
Args:
|
| 26 |
+
- kernel_size (tuple): Size of the convolutional kernel (height, width).
|
| 27 |
+
- dilation (tuple, optional): Dilation rate of the convolution (height, width). Defaults to (1, 1).
|
| 28 |
+
|
| 29 |
+
Returns:
|
| 30 |
+
- tuple: Calculated padding size (height, width).
|
| 31 |
+
"""
|
| 32 |
+
return (int((kernel_size[0] * dilation[0] - dilation[0]) / 2),
|
| 33 |
+
int((kernel_size[1] * dilation[1] - dilation[1]) / 2))
|
| 34 |
+
|
| 35 |
+
class DenseBlock(nn.Module):
|
| 36 |
+
"""
|
| 37 |
+
DenseBlock module consisting of multiple convolutional layers with dilation.
|
| 38 |
+
"""
|
| 39 |
+
def __init__(self, cfg, kernel_size=(3, 3), depth=4):
|
| 40 |
+
super(DenseBlock, self).__init__()
|
| 41 |
+
self.cfg = cfg
|
| 42 |
+
self.depth = depth
|
| 43 |
+
self.dense_block = nn.ModuleList()
|
| 44 |
+
self.hid_feature = cfg['model_cfg']['hid_feature']
|
| 45 |
+
|
| 46 |
+
for i in range(depth):
|
| 47 |
+
dil = 2 ** i
|
| 48 |
+
dense_conv = nn.Sequential(
|
| 49 |
+
nn.Conv2d(self.hid_feature * (i + 1), self.hid_feature, kernel_size,
|
| 50 |
+
dilation=(dil, 1), padding=get_padding_2d(kernel_size, (dil, 1))),
|
| 51 |
+
nn.InstanceNorm2d(self.hid_feature, affine=True),
|
| 52 |
+
nn.PReLU(self.hid_feature)
|
| 53 |
+
)
|
| 54 |
+
self.dense_block.append(dense_conv)
|
| 55 |
+
|
| 56 |
+
def forward(self, x):
|
| 57 |
+
"""
|
| 58 |
+
Forward pass for the DenseBlock module.
|
| 59 |
+
|
| 60 |
+
Args:
|
| 61 |
+
- x (torch.Tensor): Input tensor.
|
| 62 |
+
|
| 63 |
+
Returns:
|
| 64 |
+
- torch.Tensor: Output tensor after processing through the dense block.
|
| 65 |
+
"""
|
| 66 |
+
skip = x
|
| 67 |
+
for i in range(self.depth):
|
| 68 |
+
x = self.dense_block[i](skip)
|
| 69 |
+
skip = torch.cat([x, skip], dim=1)
|
| 70 |
+
return x
|
| 71 |
+
|
| 72 |
+
class DenseEncoder(nn.Module):
|
| 73 |
+
"""
|
| 74 |
+
DenseEncoder module consisting of initial convolution, dense block, and a final convolution.
|
| 75 |
+
"""
|
| 76 |
+
def __init__(self, cfg):
|
| 77 |
+
super(DenseEncoder, self).__init__()
|
| 78 |
+
self.cfg = cfg
|
| 79 |
+
self.input_channel = cfg['model_cfg']['input_channel']
|
| 80 |
+
self.hid_feature = cfg['model_cfg']['hid_feature']
|
| 81 |
+
|
| 82 |
+
self.dense_conv_1 = nn.Sequential(
|
| 83 |
+
nn.Conv2d(self.input_channel, self.hid_feature, (1, 1)),
|
| 84 |
+
nn.InstanceNorm2d(self.hid_feature, affine=True),
|
| 85 |
+
nn.PReLU(self.hid_feature)
|
| 86 |
+
)
|
| 87 |
+
|
| 88 |
+
self.dense_block = DenseBlock(cfg, depth=4)
|
| 89 |
+
|
| 90 |
+
self.dense_conv_2 = nn.Sequential(
|
| 91 |
+
nn.Conv2d(self.hid_feature, self.hid_feature, (1, 3), stride=(1, 2)),
|
| 92 |
+
nn.InstanceNorm2d(self.hid_feature, affine=True),
|
| 93 |
+
nn.PReLU(self.hid_feature)
|
| 94 |
+
)
|
| 95 |
+
|
| 96 |
+
def forward(self, x):
|
| 97 |
+
"""
|
| 98 |
+
Forward pass for the DenseEncoder module.
|
| 99 |
+
|
| 100 |
+
Args:
|
| 101 |
+
- x (torch.Tensor): Input tensor.
|
| 102 |
+
|
| 103 |
+
Returns:
|
| 104 |
+
- torch.Tensor: Encoded tensor.
|
| 105 |
+
"""
|
| 106 |
+
x = self.dense_conv_1(x) # [batch, hid_feature, time, freq]
|
| 107 |
+
x = self.dense_block(x) # [batch, hid_feature, time, freq]
|
| 108 |
+
x = self.dense_conv_2(x) # [batch, hid_feature, time, freq//2]
|
| 109 |
+
return x
|
| 110 |
+
|
| 111 |
+
class MagDecoder(nn.Module):
|
| 112 |
+
"""
|
| 113 |
+
MagDecoder module for decoding magnitude information.
|
| 114 |
+
"""
|
| 115 |
+
def __init__(self, cfg):
|
| 116 |
+
super(MagDecoder, self).__init__()
|
| 117 |
+
self.dense_block = DenseBlock(cfg, depth=4)
|
| 118 |
+
self.hid_feature = cfg['model_cfg']['hid_feature']
|
| 119 |
+
self.output_channel = cfg['model_cfg']['output_channel']
|
| 120 |
+
self.n_fft = cfg['stft_cfg']['n_fft']
|
| 121 |
+
self.beta = cfg['model_cfg']['beta']
|
| 122 |
+
|
| 123 |
+
self.mask_conv = nn.Sequential(
|
| 124 |
+
nn.ConvTranspose2d(self.hid_feature, self.hid_feature, (1, 3), stride=(1, 2)),
|
| 125 |
+
nn.Conv2d(self.hid_feature, self.output_channel, (1, 1)),
|
| 126 |
+
nn.InstanceNorm2d(self.output_channel, affine=True),
|
| 127 |
+
nn.PReLU(self.output_channel),
|
| 128 |
+
nn.Conv2d(self.output_channel, self.output_channel, (1, 1))
|
| 129 |
+
)
|
| 130 |
+
self.lsigmoid = LearnableSigmoid2D(self.n_fft // 2 + 1, beta=self.beta)
|
| 131 |
+
|
| 132 |
+
def forward(self, x):
|
| 133 |
+
"""
|
| 134 |
+
Forward pass for the MagDecoder module.
|
| 135 |
+
|
| 136 |
+
Args:
|
| 137 |
+
- x (torch.Tensor): Input tensor.
|
| 138 |
+
|
| 139 |
+
Returns:
|
| 140 |
+
- torch.Tensor: Decoded tensor with magnitude information.
|
| 141 |
+
"""
|
| 142 |
+
x = self.dense_block(x)
|
| 143 |
+
x = self.mask_conv(x)
|
| 144 |
+
x = rearrange(x, 'b c t f -> b f t c').squeeze(-1)
|
| 145 |
+
x = self.lsigmoid(x)
|
| 146 |
+
x = rearrange(x, 'b f t -> b t f').unsqueeze(1)
|
| 147 |
+
return x
|
| 148 |
+
|
| 149 |
+
class PhaseDecoder(nn.Module):
|
| 150 |
+
"""
|
| 151 |
+
PhaseDecoder module for decoding phase information.
|
| 152 |
+
"""
|
| 153 |
+
def __init__(self, cfg):
|
| 154 |
+
super(PhaseDecoder, self).__init__()
|
| 155 |
+
self.dense_block = DenseBlock(cfg, depth=4)
|
| 156 |
+
self.hid_feature = cfg['model_cfg']['hid_feature']
|
| 157 |
+
self.output_channel = cfg['model_cfg']['output_channel']
|
| 158 |
+
|
| 159 |
+
self.phase_conv = nn.Sequential(
|
| 160 |
+
nn.ConvTranspose2d(self.hid_feature, self.hid_feature, (1, 3), stride=(1, 2)),
|
| 161 |
+
nn.InstanceNorm2d(self.hid_feature, affine=True),
|
| 162 |
+
nn.PReLU(self.hid_feature)
|
| 163 |
+
)
|
| 164 |
+
|
| 165 |
+
self.phase_conv_r = nn.Conv2d(self.hid_feature, self.output_channel, (1, 1))
|
| 166 |
+
self.phase_conv_i = nn.Conv2d(self.hid_feature, self.output_channel, (1, 1))
|
| 167 |
+
|
| 168 |
+
def forward(self, x):
|
| 169 |
+
"""
|
| 170 |
+
Forward pass for the PhaseDecoder module.
|
| 171 |
+
|
| 172 |
+
Args:
|
| 173 |
+
- x (torch.Tensor): Input tensor.
|
| 174 |
+
|
| 175 |
+
Returns:
|
| 176 |
+
- torch.Tensor: Decoded tensor with phase information.
|
| 177 |
+
"""
|
| 178 |
+
x = self.dense_block(x)
|
| 179 |
+
x = self.phase_conv(x)
|
| 180 |
+
x_r = self.phase_conv_r(x)
|
| 181 |
+
x_i = self.phase_conv_i(x)
|
| 182 |
+
x = torch.atan2(x_i, x_r)
|
| 183 |
+
return x
|
models/discriminator.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# References: https://github.com/yxlu-0102/MP-SENet/blob/main/models/discriminator.py
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import numpy as np
|
| 6 |
+
from pesq import pesq
|
| 7 |
+
from joblib import Parallel, delayed
|
| 8 |
+
from models.lsigmoid import LearnableSigmoid1D
|
| 9 |
+
|
| 10 |
+
def pesq_loss(clean, noisy, sr=16000):
|
| 11 |
+
try:
|
| 12 |
+
pesq_score = pesq(sr, clean, noisy, 'wb')
|
| 13 |
+
except:
|
| 14 |
+
# error can happen due to silent period
|
| 15 |
+
pesq_score = -1
|
| 16 |
+
return pesq_score
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def batch_pesq(clean, noisy, cfg):
|
| 20 |
+
num_worker = cfg['env_setting']['num_workers']
|
| 21 |
+
pesq_score = Parallel(n_jobs=num_worker)(delayed(pesq_loss)(c, n) for c, n in zip(clean, noisy))
|
| 22 |
+
pesq_score = np.array(pesq_score)
|
| 23 |
+
if -1 in pesq_score:
|
| 24 |
+
return None
|
| 25 |
+
pesq_score = (pesq_score - 1) / 3.5
|
| 26 |
+
return torch.FloatTensor(pesq_score)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class MetricDiscriminator(nn.Module):
|
| 30 |
+
def __init__(self, dim=16, in_channel=2):
|
| 31 |
+
super(MetricDiscriminator, self).__init__()
|
| 32 |
+
self.layers = nn.Sequential(
|
| 33 |
+
nn.utils.spectral_norm(nn.Conv2d(in_channel, dim, (4,4), (2,2), (1,1), bias=False)),
|
| 34 |
+
nn.InstanceNorm2d(dim, affine=True),
|
| 35 |
+
nn.PReLU(dim),
|
| 36 |
+
nn.utils.spectral_norm(nn.Conv2d(dim, dim*2, (4,4), (2,2), (1,1), bias=False)),
|
| 37 |
+
nn.InstanceNorm2d(dim*2, affine=True),
|
| 38 |
+
nn.PReLU(dim*2),
|
| 39 |
+
nn.utils.spectral_norm(nn.Conv2d(dim*2, dim*4, (4,4), (2,2), (1,1), bias=False)),
|
| 40 |
+
nn.InstanceNorm2d(dim*4, affine=True),
|
| 41 |
+
nn.PReLU(dim*4),
|
| 42 |
+
nn.utils.spectral_norm(nn.Conv2d(dim*4, dim*8, (4,4), (2,2), (1,1), bias=False)),
|
| 43 |
+
nn.InstanceNorm2d(dim*8, affine=True),
|
| 44 |
+
nn.PReLU(dim*8),
|
| 45 |
+
nn.AdaptiveMaxPool2d(1),
|
| 46 |
+
nn.Flatten(),
|
| 47 |
+
nn.utils.spectral_norm(nn.Linear(dim*8, dim*4)),
|
| 48 |
+
nn.Dropout(0.3),
|
| 49 |
+
nn.PReLU(dim*4),
|
| 50 |
+
nn.utils.spectral_norm(nn.Linear(dim*4, 1)),
|
| 51 |
+
LearnableSigmoid1D(1)
|
| 52 |
+
)
|
| 53 |
+
|
| 54 |
+
def forward(self, x, y):
|
| 55 |
+
xy = torch.stack((x, y), dim=1)
|
| 56 |
+
return self.layers(xy)
|
models/generator.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from einops import rearrange
|
| 4 |
+
from .mamba_block import TFMambaBlock
|
| 5 |
+
from .codec_module import DenseEncoder, MagDecoder, PhaseDecoder
|
| 6 |
+
|
| 7 |
+
class SEMamba(nn.Module):
|
| 8 |
+
"""
|
| 9 |
+
SEMamba model for speech enhancement using Mamba blocks.
|
| 10 |
+
|
| 11 |
+
This model uses a dense encoder, multiple Mamba blocks, and separate magnitude
|
| 12 |
+
and phase decoders to process noisy magnitude and phase inputs.
|
| 13 |
+
"""
|
| 14 |
+
def __init__(self, cfg):
|
| 15 |
+
"""
|
| 16 |
+
Initialize the SEMamba model.
|
| 17 |
+
|
| 18 |
+
Args:
|
| 19 |
+
- cfg: Configuration object containing model parameters.
|
| 20 |
+
"""
|
| 21 |
+
super(SEMamba, self).__init__()
|
| 22 |
+
self.cfg = cfg
|
| 23 |
+
self.num_tscblocks = cfg['model_cfg']['num_tfmamba'] if cfg['model_cfg']['num_tfmamba'] is not None else 4 # default tfmamba: 4
|
| 24 |
+
|
| 25 |
+
# Initialize dense encoder
|
| 26 |
+
self.dense_encoder = DenseEncoder(cfg)
|
| 27 |
+
|
| 28 |
+
# Initialize Mamba blocks
|
| 29 |
+
self.TSMamba = nn.ModuleList([TFMambaBlock(cfg) for _ in range(self.num_tscblocks)])
|
| 30 |
+
|
| 31 |
+
# Initialize decoders
|
| 32 |
+
self.mask_decoder = MagDecoder(cfg)
|
| 33 |
+
self.phase_decoder = PhaseDecoder(cfg)
|
| 34 |
+
|
| 35 |
+
def forward(self, noisy_mag, noisy_pha):
|
| 36 |
+
"""
|
| 37 |
+
Forward pass for the SEMamba model.
|
| 38 |
+
|
| 39 |
+
Args:
|
| 40 |
+
- noisy_mag (torch.Tensor): Noisy magnitude input tensor [B, F, T].
|
| 41 |
+
- noisy_pha (torch.Tensor): Noisy phase input tensor [B, F, T].
|
| 42 |
+
|
| 43 |
+
Returns:
|
| 44 |
+
- denoised_mag (torch.Tensor): Denoised magnitude tensor [B, F, T].
|
| 45 |
+
- denoised_pha (torch.Tensor): Denoised phase tensor [B, F, T].
|
| 46 |
+
- denoised_com (torch.Tensor): Denoised complex tensor [B, F, T, 2].
|
| 47 |
+
"""
|
| 48 |
+
# Reshape inputs
|
| 49 |
+
noisy_mag = rearrange(noisy_mag, 'b f t -> b t f').unsqueeze(1) # [B, 1, T, F]
|
| 50 |
+
noisy_pha = rearrange(noisy_pha, 'b f t -> b t f').unsqueeze(1) # [B, 1, T, F]
|
| 51 |
+
|
| 52 |
+
# Concatenate magnitude and phase inputs
|
| 53 |
+
x = torch.cat((noisy_mag, noisy_pha), dim=1) # [B, 2, T, F]
|
| 54 |
+
|
| 55 |
+
# Encode input
|
| 56 |
+
x = self.dense_encoder(x)
|
| 57 |
+
|
| 58 |
+
# Apply Mamba blocks
|
| 59 |
+
for block in self.TSMamba:
|
| 60 |
+
x = block(x)
|
| 61 |
+
|
| 62 |
+
# Decode magnitude and phase
|
| 63 |
+
denoised_mag = rearrange(self.mask_decoder(x) * noisy_mag, 'b c t f -> b f t c').squeeze(-1)
|
| 64 |
+
denoised_pha = rearrange(self.phase_decoder(x), 'b c t f -> b f t c').squeeze(-1)
|
| 65 |
+
|
| 66 |
+
# Combine denoised magnitude and phase into a complex representation
|
| 67 |
+
denoised_com = torch.stack(
|
| 68 |
+
(denoised_mag * torch.cos(denoised_pha), denoised_mag * torch.sin(denoised_pha)),
|
| 69 |
+
dim=-1
|
| 70 |
+
)
|
| 71 |
+
|
| 72 |
+
return denoised_mag, denoised_pha, denoised_com
|
models/loss.py
ADDED
|
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Reference: https://github.com/yxlu-0102/MP-SENet/blob/main/models/generator.py
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import numpy as np
|
| 6 |
+
from pesq import pesq
|
| 7 |
+
from joblib import Parallel, delayed
|
| 8 |
+
|
| 9 |
+
def phase_losses(phase_r, phase_g, cfg):
|
| 10 |
+
"""
|
| 11 |
+
Calculate phase losses including in-phase loss, gradient delay loss,
|
| 12 |
+
and integrated absolute frequency loss between reference and generated phases.
|
| 13 |
+
|
| 14 |
+
Args:
|
| 15 |
+
phase_r (torch.Tensor): Reference phase tensor of shape (batch, freq, time).
|
| 16 |
+
phase_g (torch.Tensor): Generated phase tensor of shape (batch, freq, time).
|
| 17 |
+
h (object): Configuration object containing parameters like n_fft.
|
| 18 |
+
|
| 19 |
+
Returns:
|
| 20 |
+
tuple: Tuple containing in-phase loss, gradient delay loss, and integrated absolute frequency loss.
|
| 21 |
+
"""
|
| 22 |
+
dim_freq = cfg['stft_cfg']['n_fft'] // 2 + 1 # Calculate frequency dimension
|
| 23 |
+
dim_time = phase_r.size(-1) # Calculate time dimension
|
| 24 |
+
|
| 25 |
+
# Construct gradient delay matrix
|
| 26 |
+
gd_matrix = (torch.triu(torch.ones(dim_freq, dim_freq), diagonal=1) -
|
| 27 |
+
torch.triu(torch.ones(dim_freq, dim_freq), diagonal=2) -
|
| 28 |
+
torch.eye(dim_freq)).to(phase_g.device)
|
| 29 |
+
|
| 30 |
+
# Apply gradient delay matrix to reference and generated phases
|
| 31 |
+
gd_r = torch.matmul(phase_r.permute(0, 2, 1), gd_matrix)
|
| 32 |
+
gd_g = torch.matmul(phase_g.permute(0, 2, 1), gd_matrix)
|
| 33 |
+
|
| 34 |
+
# Construct integrated absolute frequency matrix
|
| 35 |
+
iaf_matrix = (torch.triu(torch.ones(dim_time, dim_time), diagonal=1) -
|
| 36 |
+
torch.triu(torch.ones(dim_time, dim_time), diagonal=2) -
|
| 37 |
+
torch.eye(dim_time)).to(phase_g.device)
|
| 38 |
+
|
| 39 |
+
# Apply integrated absolute frequency matrix to reference and generated phases
|
| 40 |
+
iaf_r = torch.matmul(phase_r, iaf_matrix)
|
| 41 |
+
iaf_g = torch.matmul(phase_g, iaf_matrix)
|
| 42 |
+
|
| 43 |
+
# Calculate losses
|
| 44 |
+
ip_loss = torch.mean(anti_wrapping_function(phase_r - phase_g))
|
| 45 |
+
gd_loss = torch.mean(anti_wrapping_function(gd_r - gd_g))
|
| 46 |
+
iaf_loss = torch.mean(anti_wrapping_function(iaf_r - iaf_g))
|
| 47 |
+
|
| 48 |
+
return ip_loss, gd_loss, iaf_loss
|
| 49 |
+
|
| 50 |
+
def anti_wrapping_function(x):
|
| 51 |
+
"""
|
| 52 |
+
Anti-wrapping function to adjust phase values within the range of -pi to pi.
|
| 53 |
+
|
| 54 |
+
Args:
|
| 55 |
+
x (torch.Tensor): Input tensor representing phase differences.
|
| 56 |
+
|
| 57 |
+
Returns:
|
| 58 |
+
torch.Tensor: Adjusted tensor with phase values wrapped within -pi to pi.
|
| 59 |
+
"""
|
| 60 |
+
return torch.abs(x - torch.round(x / (2 * np.pi)) * 2 * np.pi)
|
| 61 |
+
|
| 62 |
+
def compute_stft(y: torch.Tensor, n_fft: int, hop_size: int, win_size: int, center: bool, compress_factor: float = 1.0) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 63 |
+
"""
|
| 64 |
+
Compute the Short-Time Fourier Transform (STFT) and return magnitude, phase, and complex components.
|
| 65 |
+
|
| 66 |
+
Args:
|
| 67 |
+
y (torch.Tensor): Input signal tensor.
|
| 68 |
+
n_fft (int): Number of FFT points.
|
| 69 |
+
hop_size (int): Hop size for STFT.
|
| 70 |
+
win_size (int): Window size for STFT.
|
| 71 |
+
center (bool): Whether to pad the input on both sides.
|
| 72 |
+
compress_factor (float, optional): Compression factor for magnitude. Defaults to 1.0.
|
| 73 |
+
|
| 74 |
+
Returns:
|
| 75 |
+
tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Magnitude, phase, and complex components.
|
| 76 |
+
"""
|
| 77 |
+
eps = torch.finfo(y.dtype).eps
|
| 78 |
+
hann_window = torch.hann_window(win_size).to(y.device)
|
| 79 |
+
|
| 80 |
+
stft_spec = torch.stft(
|
| 81 |
+
y,
|
| 82 |
+
n_fft=n_fft,
|
| 83 |
+
hop_length=hop_size,
|
| 84 |
+
win_length=win_size,
|
| 85 |
+
window=hann_window,
|
| 86 |
+
center=center,
|
| 87 |
+
pad_mode='reflect',
|
| 88 |
+
normalized=False,
|
| 89 |
+
return_complex=True
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
real_part = stft_spec.real
|
| 93 |
+
imag_part = stft_spec.imag
|
| 94 |
+
|
| 95 |
+
mag = torch.sqrt( real_part.pow(2) * imag_part.pow(2) + eps )
|
| 96 |
+
pha = torch.atan2( real_part + eps, imag_part + eps )
|
| 97 |
+
|
| 98 |
+
mag = torch.pow(mag, compress_factor)
|
| 99 |
+
com = torch.stack((mag * torch.cos(pha), mag * torch.sin(pha)), dim=-1)
|
| 100 |
+
|
| 101 |
+
return mag, pha, com
|
| 102 |
+
|
| 103 |
+
def pesq_score(utts_r, utts_g, cfg):
|
| 104 |
+
"""
|
| 105 |
+
Calculate PESQ (Perceptual Evaluation of Speech Quality) score for pairs of reference and generated utterances.
|
| 106 |
+
|
| 107 |
+
Args:
|
| 108 |
+
utts_r (list of torch.Tensor): List of reference utterances.
|
| 109 |
+
utts_g (list of torch.Tensor): List of generated utterances.
|
| 110 |
+
h (object): Configuration object containing parameters like sampling_rate.
|
| 111 |
+
|
| 112 |
+
Returns:
|
| 113 |
+
float: Mean PESQ score across all pairs of utterances.
|
| 114 |
+
"""
|
| 115 |
+
def eval_pesq(clean_utt, esti_utt, sr):
|
| 116 |
+
"""
|
| 117 |
+
Evaluate PESQ score for a single pair of clean and estimated utterances.
|
| 118 |
+
|
| 119 |
+
Args:
|
| 120 |
+
clean_utt (np.ndarray): Clean reference utterance.
|
| 121 |
+
esti_utt (np.ndarray): Estimated generated utterance.
|
| 122 |
+
sr (int): Sampling rate.
|
| 123 |
+
|
| 124 |
+
Returns:
|
| 125 |
+
float: PESQ score or -1 in case of an error.
|
| 126 |
+
"""
|
| 127 |
+
try:
|
| 128 |
+
pesq_score = pesq(sr, clean_utt, esti_utt)
|
| 129 |
+
except Exception as e:
|
| 130 |
+
# Error can happen due to silent period or other issues
|
| 131 |
+
print(f"Error computing PESQ score: {e}")
|
| 132 |
+
pesq_score = -1
|
| 133 |
+
return pesq_score
|
| 134 |
+
|
| 135 |
+
# Parallel processing of PESQ score computation
|
| 136 |
+
pesq_scores = Parallel(n_jobs=30)(delayed(eval_pesq)(
|
| 137 |
+
utts_r[i].squeeze().cpu().numpy(),
|
| 138 |
+
utts_g[i].squeeze().cpu().numpy(),
|
| 139 |
+
cfg['stft_cfg']['sampling_rate']
|
| 140 |
+
) for i in range(len(utts_r)))
|
| 141 |
+
|
| 142 |
+
# Calculate mean PESQ score
|
| 143 |
+
pesq_score = np.mean(pesq_scores)
|
| 144 |
+
return pesq_score
|
| 145 |
+
|
models/lsigmoid.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Reference: https://github.com/yxlu-0102/MP-SENet/blob/main/utils.py
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
|
| 6 |
+
class LearnableSigmoid1D(nn.Module):
|
| 7 |
+
"""
|
| 8 |
+
Learnable Sigmoid Activation Function for 1D inputs.
|
| 9 |
+
|
| 10 |
+
This module applies a learnable slope parameter to the sigmoid activation function.
|
| 11 |
+
"""
|
| 12 |
+
def __init__(self, in_features, beta=1):
|
| 13 |
+
"""
|
| 14 |
+
Initialize the LearnableSigmoid1D module.
|
| 15 |
+
|
| 16 |
+
Args:
|
| 17 |
+
- in_features (int): Number of input features.
|
| 18 |
+
- beta (float, optional): Scaling factor for the sigmoid function. Defaults to 1.
|
| 19 |
+
"""
|
| 20 |
+
super(LearnableSigmoid1D, self).__init__()
|
| 21 |
+
self.beta = beta
|
| 22 |
+
self.slope = nn.Parameter(torch.ones(in_features))
|
| 23 |
+
self.slope.requires_grad = True
|
| 24 |
+
|
| 25 |
+
def forward(self, x):
|
| 26 |
+
"""
|
| 27 |
+
Forward pass for the LearnableSigmoid1D module.
|
| 28 |
+
|
| 29 |
+
Args:
|
| 30 |
+
- x (torch.Tensor): Input tensor.
|
| 31 |
+
|
| 32 |
+
Returns:
|
| 33 |
+
- torch.Tensor: Output tensor after applying the learnable sigmoid activation.
|
| 34 |
+
"""
|
| 35 |
+
return self.beta * torch.sigmoid(self.slope * x)
|
| 36 |
+
|
| 37 |
+
class LearnableSigmoid2D(nn.Module):
|
| 38 |
+
"""
|
| 39 |
+
Learnable Sigmoid Activation Function for 2D inputs.
|
| 40 |
+
|
| 41 |
+
This module applies a learnable slope parameter to the sigmoid activation function for 2D inputs.
|
| 42 |
+
"""
|
| 43 |
+
def __init__(self, in_features, beta=1):
|
| 44 |
+
"""
|
| 45 |
+
Initialize the LearnableSigmoid2D module.
|
| 46 |
+
|
| 47 |
+
Args:
|
| 48 |
+
- in_features (int): Number of input features.
|
| 49 |
+
- beta (float, optional): Scaling factor for the sigmoid function. Defaults to 1.
|
| 50 |
+
"""
|
| 51 |
+
super(LearnableSigmoid2D, self).__init__()
|
| 52 |
+
self.beta = beta
|
| 53 |
+
self.slope = nn.Parameter(torch.ones(in_features, 1))
|
| 54 |
+
self.slope.requires_grad = True
|
| 55 |
+
|
| 56 |
+
def forward(self, x):
|
| 57 |
+
"""
|
| 58 |
+
Forward pass for the LearnableSigmoid2D module.
|
| 59 |
+
|
| 60 |
+
Args:
|
| 61 |
+
- x (torch.Tensor): Input tensor.
|
| 62 |
+
|
| 63 |
+
Returns:
|
| 64 |
+
- torch.Tensor: Output tensor after applying the learnable sigmoid activation.
|
| 65 |
+
"""
|
| 66 |
+
return self.beta * torch.sigmoid(self.slope * x)
|
models/mamba_block.py
ADDED
|
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Reference: https://github.com/state-spaces/mamba/blob/9127d1f47f367f5c9cc49c73ad73557089d02cb8/mamba_ssm/models/mixer_seq_simple.py
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
from torch.nn import init
|
| 7 |
+
from torch.nn.parameter import Parameter
|
| 8 |
+
from functools import partial
|
| 9 |
+
from einops import rearrange
|
| 10 |
+
|
| 11 |
+
from mamba_ssm.modules.mamba_simple import Mamba, Block
|
| 12 |
+
from mamba_ssm.models.mixer_seq_simple import _init_weights
|
| 13 |
+
from mamba_ssm.ops.triton.layernorm import RMSNorm
|
| 14 |
+
|
| 15 |
+
# github: https://github.com/state-spaces/mamba/blob/9127d1f47f367f5c9cc49c73ad73557089d02cb8/mamba_ssm/models/mixer_seq_simple.py
|
| 16 |
+
def create_block(
|
| 17 |
+
d_model, cfg, layer_idx=0, rms_norm=True, fused_add_norm=False, residual_in_fp32=False,
|
| 18 |
+
):
|
| 19 |
+
d_state = cfg['model_cfg']['d_state'] # 16
|
| 20 |
+
d_conv = cfg['model_cfg']['d_conv'] # 4
|
| 21 |
+
expand = cfg['model_cfg']['expand'] # 4
|
| 22 |
+
norm_epsilon = cfg['model_cfg']['norm_epsilon'] # 0.00001
|
| 23 |
+
|
| 24 |
+
mixer_cls = partial(Mamba, layer_idx=layer_idx, d_state=d_state, d_conv=d_conv, expand=expand)
|
| 25 |
+
norm_cls = partial(
|
| 26 |
+
nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon
|
| 27 |
+
)
|
| 28 |
+
block = Block(
|
| 29 |
+
d_model,
|
| 30 |
+
mixer_cls,
|
| 31 |
+
norm_cls=norm_cls,
|
| 32 |
+
fused_add_norm=fused_add_norm,
|
| 33 |
+
residual_in_fp32=residual_in_fp32,
|
| 34 |
+
)
|
| 35 |
+
block.layer_idx = layer_idx
|
| 36 |
+
return block
|
| 37 |
+
|
| 38 |
+
class MambaBlock(nn.Module):
|
| 39 |
+
def __init__(self, in_channels, cfg):
|
| 40 |
+
super(MambaBlock, self).__init__()
|
| 41 |
+
n_layer = 1
|
| 42 |
+
self.forward_blocks = nn.ModuleList( create_block(in_channels, cfg) for i in range(n_layer) )
|
| 43 |
+
self.backward_blocks = nn.ModuleList( create_block(in_channels, cfg) for i in range(n_layer) )
|
| 44 |
+
|
| 45 |
+
self.apply(
|
| 46 |
+
partial(
|
| 47 |
+
_init_weights,
|
| 48 |
+
n_layer=n_layer,
|
| 49 |
+
)
|
| 50 |
+
)
|
| 51 |
+
|
| 52 |
+
def forward(self, x):
|
| 53 |
+
x_forward, x_backward = x.clone(), torch.flip(x, [1])
|
| 54 |
+
resi_forward, resi_backward = None, None
|
| 55 |
+
|
| 56 |
+
# Forward
|
| 57 |
+
for layer in self.forward_blocks:
|
| 58 |
+
x_forward, resi_forward = layer(x_forward, resi_forward)
|
| 59 |
+
y_forward = (x_forward + resi_forward) if resi_forward is not None else x_forward
|
| 60 |
+
|
| 61 |
+
# Backward
|
| 62 |
+
for layer in self.backward_blocks:
|
| 63 |
+
x_backward, resi_backward = layer(x_backward, resi_backward)
|
| 64 |
+
y_backward = torch.flip((x_backward + resi_backward), [1]) if resi_backward is not None else torch.flip(x_backward, [1])
|
| 65 |
+
|
| 66 |
+
return torch.cat([y_forward, y_backward], -1)
|
| 67 |
+
|
| 68 |
+
class TFMambaBlock(nn.Module):
|
| 69 |
+
"""
|
| 70 |
+
Temporal-Frequency Mamba block for sequence modeling.
|
| 71 |
+
|
| 72 |
+
Attributes:
|
| 73 |
+
cfg (Config): Configuration for the block.
|
| 74 |
+
time_mamba (MambaBlock): Mamba block for temporal dimension.
|
| 75 |
+
freq_mamba (MambaBlock): Mamba block for frequency dimension.
|
| 76 |
+
tlinear (ConvTranspose1d): ConvTranspose1d layer for temporal dimension.
|
| 77 |
+
flinear (ConvTranspose1d): ConvTranspose1d layer for frequency dimension.
|
| 78 |
+
"""
|
| 79 |
+
def __init__(self, cfg):
|
| 80 |
+
super(TFMambaBlock, self).__init__()
|
| 81 |
+
self.cfg = cfg
|
| 82 |
+
self.hid_feature = cfg['model_cfg']['hid_feature']
|
| 83 |
+
|
| 84 |
+
# Initialize Mamba blocks
|
| 85 |
+
self.time_mamba = MambaBlock(in_channels=self.hid_feature, cfg=cfg)
|
| 86 |
+
self.freq_mamba = MambaBlock(in_channels=self.hid_feature, cfg=cfg)
|
| 87 |
+
|
| 88 |
+
# Initialize ConvTranspose1d layers
|
| 89 |
+
self.tlinear = nn.ConvTranspose1d(self.hid_feature * 2, self.hid_feature, 1, stride=1)
|
| 90 |
+
self.flinear = nn.ConvTranspose1d(self.hid_feature * 2, self.hid_feature, 1, stride=1)
|
| 91 |
+
|
| 92 |
+
def forward(self, x):
|
| 93 |
+
"""
|
| 94 |
+
Forward pass of the TFMamba block.
|
| 95 |
+
|
| 96 |
+
Parameters:
|
| 97 |
+
x (Tensor): Input tensor with shape (batch, channels, time, freq).
|
| 98 |
+
|
| 99 |
+
Returns:
|
| 100 |
+
Tensor: Output tensor after applying temporal and frequency Mamba blocks.
|
| 101 |
+
"""
|
| 102 |
+
b, c, t, f = x.size()
|
| 103 |
+
|
| 104 |
+
x = x.permute(0, 3, 2, 1).contiguous().view(b*f, t, c)
|
| 105 |
+
x = self.tlinear( self.time_mamba(x).permute(0,2,1) ).permute(0,2,1) + x
|
| 106 |
+
x = x.view(b, f, t, c).permute(0, 2, 1, 3).contiguous().view(b*t, f, c)
|
| 107 |
+
x = self.flinear( self.freq_mamba(x).permute(0,2,1) ).permute(0,2,1) + x
|
| 108 |
+
x = x.view(b, t, f, c).permute(0, 3, 1, 2)
|
| 109 |
+
return x
|
| 110 |
+
|
models/pcs400.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import torchaudio
|
| 4 |
+
import numpy as np
|
| 5 |
+
import argparse
|
| 6 |
+
import librosa
|
| 7 |
+
import scipy
|
| 8 |
+
|
| 9 |
+
# PCS400 parameters
|
| 10 |
+
PCS400 = np.ones(201)
|
| 11 |
+
PCS400[0:3] = 1
|
| 12 |
+
PCS400[3:5] = 1.070175439
|
| 13 |
+
PCS400[5:8] = 1.182456140
|
| 14 |
+
PCS400[8:10] = 1.287719298
|
| 15 |
+
PCS400[10:110] = 1.4 # Pre Set
|
| 16 |
+
PCS400[110:130] = 1.322807018
|
| 17 |
+
PCS400[130:160] = 1.238596491
|
| 18 |
+
PCS400[160:190] = 1.161403509
|
| 19 |
+
PCS400[190:202] = 1.077192982
|
| 20 |
+
|
| 21 |
+
maxv = np.iinfo(np.int16).max
|
| 22 |
+
|
| 23 |
+
def Sp_and_phase(signal):
|
| 24 |
+
signal_length = signal.shape[0]
|
| 25 |
+
n_fft = 400
|
| 26 |
+
hop_length = 100
|
| 27 |
+
y_pad = librosa.util.fix_length(signal, size=signal_length + n_fft // 2)
|
| 28 |
+
|
| 29 |
+
F = librosa.stft(y_pad, n_fft=400, hop_length=100, win_length=400, window=scipy.signal.windows.hamming(400))
|
| 30 |
+
Lp = PCS400 * np.transpose(np.log1p(np.abs(F)), (1, 0))
|
| 31 |
+
phase = np.angle(F)
|
| 32 |
+
|
| 33 |
+
NLp = np.transpose(Lp, (1, 0))
|
| 34 |
+
|
| 35 |
+
return NLp, phase, signal_length
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def SP_to_wav(mag, phase, signal_length):
|
| 39 |
+
mag = np.expm1(mag)
|
| 40 |
+
Rec = np.multiply(mag, np.exp(1j*phase))
|
| 41 |
+
result = librosa.istft(Rec,
|
| 42 |
+
hop_length=100,
|
| 43 |
+
win_length=400,
|
| 44 |
+
window=scipy.signal.windows.hamming(400),
|
| 45 |
+
length=signal_length)
|
| 46 |
+
return result
|
| 47 |
+
|
| 48 |
+
def cal_pcs(signal_wav):
|
| 49 |
+
noisy_LP, Nphase, signal_length = Sp_and_phase(signal_wav.squeeze())
|
| 50 |
+
enhanced_wav = SP_to_wav(noisy_LP, Nphase, signal_length)
|
| 51 |
+
enhanced_wav = enhanced_wav/np.max(abs(enhanced_wav))
|
| 52 |
+
|
| 53 |
+
return enhanced_wav
|
models/stfts.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
def mag_phase_stft(y, n_fft, hop_size, win_size, compress_factor=1.0, center=True, addeps=False):
|
| 5 |
+
"""
|
| 6 |
+
Compute magnitude and phase using STFT.
|
| 7 |
+
|
| 8 |
+
Args:
|
| 9 |
+
y (torch.Tensor): Input audio signal.
|
| 10 |
+
n_fft (int): FFT size.
|
| 11 |
+
hop_size (int): Hop size.
|
| 12 |
+
win_size (int): Window size.
|
| 13 |
+
compress_factor (float, optional): Magnitude compression factor. Defaults to 1.0.
|
| 14 |
+
center (bool, optional): Whether to center the signal before padding. Defaults to True.
|
| 15 |
+
eps (bool, optional): Whether adding epsilon to magnitude and phase or not. Defaults to False.
|
| 16 |
+
|
| 17 |
+
Returns:
|
| 18 |
+
tuple: Magnitude, phase, and complex representation of the STFT.
|
| 19 |
+
"""
|
| 20 |
+
#eps = torch.finfo(y.dtype).eps
|
| 21 |
+
eps = 1e-10
|
| 22 |
+
hann_window = torch.hann_window(win_size).to(y.device)
|
| 23 |
+
stft_spec = torch.stft(
|
| 24 |
+
y, n_fft,
|
| 25 |
+
hop_length=hop_size,
|
| 26 |
+
win_length=win_size,
|
| 27 |
+
window=hann_window,
|
| 28 |
+
center=center,
|
| 29 |
+
pad_mode='reflect',
|
| 30 |
+
normalized=False,
|
| 31 |
+
return_complex=True)
|
| 32 |
+
|
| 33 |
+
if addeps==False:
|
| 34 |
+
mag = torch.abs(stft_spec)
|
| 35 |
+
pha = torch.angle(stft_spec)
|
| 36 |
+
else:
|
| 37 |
+
real_part = stft_spec.real
|
| 38 |
+
imag_part = stft_spec.imag
|
| 39 |
+
mag = torch.sqrt(real_part.pow(2) + imag_part.pow(2) + eps)
|
| 40 |
+
pha = torch.atan2(imag_part + eps, real_part + eps)
|
| 41 |
+
# Compress the magnitude
|
| 42 |
+
mag = torch.pow(mag, compress_factor)
|
| 43 |
+
com = torch.stack((mag * torch.cos(pha), mag * torch.sin(pha)), dim=-1)
|
| 44 |
+
return mag, pha, com
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def mag_phase_istft(mag, pha, n_fft, hop_size, win_size, compress_factor=1.0, center=True):
|
| 48 |
+
"""
|
| 49 |
+
Inverse STFT to reconstruct the audio signal from magnitude and phase.
|
| 50 |
+
|
| 51 |
+
Args:
|
| 52 |
+
mag (torch.Tensor): Magnitude of the STFT.
|
| 53 |
+
pha (torch.Tensor): Phase of the STFT.
|
| 54 |
+
n_fft (int): FFT size.
|
| 55 |
+
hop_size (int): Hop size.
|
| 56 |
+
win_size (int): Window size.
|
| 57 |
+
compress_factor (float, optional): Magnitude compression factor. Defaults to 1.0.
|
| 58 |
+
center (bool, optional): Whether to center the signal before padding. Defaults to True.
|
| 59 |
+
|
| 60 |
+
Returns:
|
| 61 |
+
torch.Tensor: Reconstructed audio signal.
|
| 62 |
+
"""
|
| 63 |
+
mag = torch.pow(mag, 1.0 / compress_factor)
|
| 64 |
+
com = torch.complex(mag * torch.cos(pha), mag * torch.sin(pha))
|
| 65 |
+
hann_window = torch.hann_window(win_size).to(com.device)
|
| 66 |
+
wav = torch.istft(
|
| 67 |
+
com,
|
| 68 |
+
n_fft,
|
| 69 |
+
hop_length=hop_size,
|
| 70 |
+
win_length=win_size,
|
| 71 |
+
window=hann_window,
|
| 72 |
+
center=center)
|
| 73 |
+
return wav
|
recipes/SEMamba_advanced.yaml
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Environment Settings
|
| 2 |
+
# These settings specify the hardware and distributed setup for the model training.
|
| 3 |
+
# Adjust `num_gpus` and `dist_config` according to your distributed training environment.
|
| 4 |
+
env_setting:
|
| 5 |
+
num_gpus: 2 # Number of GPUs. Now we don't support CPU mode.
|
| 6 |
+
num_workers: 20 # Number of worker threads for data loading.
|
| 7 |
+
seed: 1234 # Seed for random number generators to ensure reproducibility.
|
| 8 |
+
stdout_interval: 10
|
| 9 |
+
checkpoint_interval: 1000 # save model to ckpt every N steps
|
| 10 |
+
validation_interval: 1000
|
| 11 |
+
summary_interval: 100
|
| 12 |
+
dist_cfg:
|
| 13 |
+
dist_backend: nccl # Distributed training backend, 'nccl' for NVIDIA GPUs.
|
| 14 |
+
dist_url: tcp://localhost:19477 # URL for initializing distributed training.
|
| 15 |
+
world_size: 1 # Total number of processes in the distributed training.
|
| 16 |
+
|
| 17 |
+
# Datapath Configuratoin
|
| 18 |
+
data_cfg:
|
| 19 |
+
train_clean_json: data/train_clean.json
|
| 20 |
+
train_noisy_json: data/train_noisy.json
|
| 21 |
+
valid_clean_json: data/valid_clean.json
|
| 22 |
+
valid_noisy_json: data/valid_noisy.json
|
| 23 |
+
test_clean_json: data/test_clean.json
|
| 24 |
+
test_noisy_json: data/test_noisy.json
|
| 25 |
+
|
| 26 |
+
# Training Configuration
|
| 27 |
+
# This section details parameters that directly influence the training process,
|
| 28 |
+
# including batch sizes, learning rates, and optimizer specifics.
|
| 29 |
+
training_cfg:
|
| 30 |
+
training_epochs: 200 # Training epoch.
|
| 31 |
+
batch_size: 4 # Training batch size.
|
| 32 |
+
learning_rate: 0.0005 # Initial learning rate.
|
| 33 |
+
adam_b1: 0.8 # Beta1 hyperparameter for the AdamW optimizer.
|
| 34 |
+
adam_b2: 0.99 # Beta2 hyperparameter for the AdamW optimizer.
|
| 35 |
+
lr_decay: 0.99 # Learning rate decay per epoch.
|
| 36 |
+
segment_size: 32000 # Audio segment size used during training, dependent on sampling rate.
|
| 37 |
+
loss:
|
| 38 |
+
metric: 0.05
|
| 39 |
+
magnitude: 0.9
|
| 40 |
+
phase: 0.3
|
| 41 |
+
complex: 0.1
|
| 42 |
+
time: 0.2
|
| 43 |
+
consistancy: 0.1
|
| 44 |
+
use_PCS400: False # Use PCS or not
|
| 45 |
+
|
| 46 |
+
# STFT Configuration
|
| 47 |
+
# Configuration for Short-Time Fourier Transform (STFT), crucial for audio processing models.
|
| 48 |
+
stft_cfg:
|
| 49 |
+
sampling_rate: 16000 # Audio sampling rate in Hz.
|
| 50 |
+
n_fft: 400 # FFT components for transforming audio signals.
|
| 51 |
+
hop_size: 100 # Samples between successive frames.
|
| 52 |
+
win_size: 400 # Window size used in FFT.
|
| 53 |
+
|
| 54 |
+
# Model Configuration
|
| 55 |
+
# Defines the architecture specifics of the model, including layer configurations and feature compression.
|
| 56 |
+
model_cfg:
|
| 57 |
+
hid_feature: 64 # Channels in dense layers.
|
| 58 |
+
compress_factor: 0.3 # Compression factor applied to extracted features.
|
| 59 |
+
num_tfmamba: 4 # Number of Time-Frequency Mamba (TFMamba) blocks in the model.
|
| 60 |
+
d_state: 16 # Dimensionality of the state vector in Mamba blocks.
|
| 61 |
+
d_conv: 4 # Convolutional layer dimensionality within Mamba blocks.
|
| 62 |
+
expand: 4 # Expansion factor for the layers within the Mamba blocks.
|
| 63 |
+
norm_epsilon: 0.00001 # Numerical stability in normalization layers within the Mamba blocks.
|
| 64 |
+
beta: 2.0 # Hyperparameter for the Learnable Sigmoid function.
|
| 65 |
+
input_channel: 2 # Magnitude and Phase
|
| 66 |
+
output_channel: 1 # Single Channel Speech Enhancement
|
requirements.txt
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
packaging
|
| 2 |
+
librosa
|
| 3 |
+
soundfile
|
| 4 |
+
pyyaml
|
| 5 |
+
argparse
|
| 6 |
+
tensorboard
|
| 7 |
+
pesq
|
| 8 |
+
einops
|
| 9 |
+
matplotlib
|
| 10 |
+
torch==2.5.1
|
| 11 |
+
torchaudio==2.5.1
|
| 12 |
+
numpy==1.26.4
|
| 13 |
+
ultralytics
|
| 14 |
+
moviepy
|
| 15 |
+
supervision
|
| 16 |
+
opencv-python
|
| 17 |
+
ffmpeg-python
|
| 18 |
+
decord==0.6.0
|
| 19 |
+
pytorch_lightning==1.9.0
|
| 20 |
+
typeguard==2.13.3
|
| 21 |
+
torch_complex
|
| 22 |
+
rich
|
yolov8n-face.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d17b38523a994b13ee604b67f02791ca0f43b9f446a32fd7bc44e17c56ead077
|
| 3 |
+
size 6250099
|