import time |
|
from dataclasses import dataclass |
|
|
|
import torch |
|
from modules.flash_attn import FlashAttention |
|
from modules.lite_mla import LiteMLA |
|
from modules.triton_lite_mla import TritonLiteMLA |
|
from modules.triton_lite_mla_fwd import TritonLiteMLAFwd |
|
from modules.utils.dtype import get_dtype_from_str |
|
from modules.utils.export_onnx import export_onnx |
|
from omegaconf import OmegaConf |
|
from torch import nn |
|
from torch.nn import functional as F |
|
from torchprofile import profile_macs |
|
|
|
|
|
@dataclass |
|
class DevelopTritonLiteMLAConfig: |
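    """Configuration for the LiteMLA benchmark/debug script.

    The benchmark input has shape (batch_size, input_size**2, num_channels),
    i.e. input_size is the feature-map side length (default 1024 // 8 // 2 = 64).
    """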
|
batch_size: int = 16 |
|
input_size: int = 1024 // 8 // 2 |
|
num_channels: int = 1152 |
|
num_heads: int = 36 |
|
attn_type: str = "LiteMLA" |
|
|
|
device: str = "cuda" |
|
dtype: str = "fp16" |
|
|
|
profile_macs: bool = False |
|
test_correctness: bool = False |
|
warmup_iterations: int = 50 |
|
iterations: int = 1000 |
|
random_weight: bool = True |
|
backward: bool = False |
|
autocast: bool = False |
|
use_cuda_graph: bool = False |
|
|
|
export_model: bool = False |
|
opset: int = 17 |
|
export_path: str = "" |
|
export_dtype: str = "fp32" |
|
export_device: str = "cuda" |
|
|
|
|
|
def simulate_litemla( |
|
x: torch.Tensor, |
|
qkv_weight: torch.Tensor, |
|
proj_weight: torch.Tensor, |
|
proj_bias: torch.Tensor, |
|
num_heads: int, |
|
head_dim: int, |
|
eps: float, |
|
backward: bool, |
|
): |
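    """Pure-PyTorch reference for ReLU linear attention (LiteMLA).

    q, k, v are reshaped to (B, heads, head_dim, N) (k is kept transposed), and
    ReLU is applied to q and k. Per token, the output is
    (V @ K^T @ Q) / (sum_n k_n . q + eps); the row of ones appended to v via
    F.pad makes the same two matmuls produce that denominator as an extra row.
    Intermediates are returned (and retain_grad'ed when backward=True) so their
    gradients can be inspected when debugging other implementations.
    """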
|
B, N, C = x.shape |
|
qkv = F.linear(x, qkv_weight).reshape(B, N, 3, C).permute(0, 2, 3, 1) |
|
q, k, v = qkv.unbind(1) |
|
|
|
q = q.reshape(B, C // head_dim, head_dim, N) |
|
k = k.reshape(B, C // head_dim, head_dim, N).transpose(-1, -2) |
|
v = v.reshape(B, C // head_dim, head_dim, N) |
|
|
|
q = F.relu(q) |
|
k = F.relu(k) |
|
|
|
q, k, v = q.float(), k.float(), v.float() |
|
if backward: |
|
k.retain_grad() |
|
v.retain_grad() |
|
q.retain_grad() |
|
v_pad = F.pad(v, (0, 0, 0, 1), mode="constant", value=1) |
|
vk = torch.matmul(v_pad, k) |
|
if backward: |
|
vk.retain_grad() |
|
vk_q = torch.matmul(vk, q) |
|
vk_q_numerator, vk_q_denominator = vk_q[:, :, :-1], vk_q[:, :, -1:] |
|
if backward: |
|
vk_q_numerator.retain_grad() |
|
vk_q_denominator.retain_grad() |
|
vk_q_divide = (vk_q_numerator / (vk_q_denominator + eps)).to(x.dtype) |
|
|
|
proj_input = vk_q_divide.view(B, C, N).permute(0, 2, 1) |
|
if backward: |
|
proj_input.retain_grad() |
|
y = F.linear(proj_input, proj_weight, proj_bias) |
|
output_dict = { |
|
"q": q, |
|
"k": k, |
|
"v": v, |
|
"vk": vk, |
|
"proj_input": proj_input, |
|
"vk_q_numerator": vk_q_numerator, |
|
"vk_q_denominator": vk_q_denominator, |
|
"vk_q_divide": vk_q_divide, |
|
"y": y, |
|
} |
|
return output_dict |
|
|
|
|
|
def main(): |
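    # Allow TF32 matmuls, force fp32 attention inside LiteMLA, and seed the RNGs
    # so benchmark inputs are reproducible across runs.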
|
torch.backends.cuda.matmul.allow_tf32 = True |
|
torch.backends.cudnn.allow_tf32 = True |
|
LiteMLA.fp32_attention = True |
|
torch.cuda.manual_seed(0) |
|
torch.manual_seed(0) |
|
|
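    # Defaults come from the structured dataclass; any field can be overridden from
    # the command line via OmegaConf dotlist syntax (e.g. `attn_type=TritonLiteMLA`).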
|
cfg = OmegaConf.structured(DevelopTritonLiteMLAConfig) |
|
cli_cfg = OmegaConf.from_cli() |
|
cfg = OmegaConf.merge(cfg, OmegaConf.masked_copy(cli_cfg, cfg.keys())) |
|
cfg: DevelopTritonLiteMLAConfig = OmegaConf.to_object(cfg) |
|
|
|
torch.set_grad_enabled(cfg.backward) |
|
|
|
device = torch.device("cuda") |
|
if cfg.autocast: |
|
dtype = torch.float32 |
|
autocast_dtype = get_dtype_from_str(cfg.dtype) |
|
else: |
|
dtype = get_dtype_from_str(cfg.dtype) |
|
autocast_dtype = None |
|
|
|
if cfg.attn_type == "LiteMLA": |
|
block = LiteMLA(cfg.num_channels, cfg.num_channels, dim=cfg.num_channels // cfg.num_heads, eps=1e-8) |
|
elif cfg.attn_type == "TritonLiteMLA": |
|
block = TritonLiteMLA(cfg.num_channels, cfg.num_heads, eps=1e-8) |
|
elif cfg.attn_type == "TritonLiteMLAFwd": |
|
block = TritonLiteMLAFwd(cfg.num_channels, cfg.num_heads, eps=1e-8) |
|
elif cfg.attn_type == "FlashAttention": |
|
block = FlashAttention(cfg.num_channels, cfg.num_heads) |
|
else: |
|
raise NotImplementedError |
|
|
|
if not cfg.backward: |
|
block = block.eval() |
|
block = block.to(device=device, dtype=dtype, memory_format=torch.channels_last) |
|
|
|
if cfg.random_weight: |
|
for param in block.parameters(): |
|
nn.init.trunc_normal_(param, std=0.001) |
|
|
|
    if cfg.profile_macs:
        # profile_macs needs an example input; build one matching the benchmark shape.
        x = torch.randn(cfg.batch_size, cfg.input_size**2, cfg.num_channels, device=device, dtype=dtype)
        macs = profile_macs(block, x)
        print(f"macs: {macs}")
|
|
|
if cfg.export_model: |
|
export_dtype = get_dtype_from_str(cfg.export_dtype) |
|
export_device = torch.device(cfg.export_device) |
|
assert cfg.export_path != "" |
|
export_onnx( |
|
block.to(device=export_device, dtype=export_dtype), |
|
(1, cfg.input_size**2, cfg.num_channels), |
|
cfg.export_path, |
|
cfg.opset, |
|
export_dtype, |
|
export_device, |
|
) |
|
if cfg.test_correctness: |
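        # Compare the selected block against a freshly built fp32 LiteMLA and the
        # hand-written reference math in simulate_litemla (forward outputs and,
        # optionally, gradients of the weights and the input).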
|
ref_block = ( |
|
LiteMLA(cfg.num_channels, cfg.num_channels, dim=cfg.num_channels // cfg.num_heads, eps=1e-8) |
|
.eval() |
|
.to(device=device, memory_format=torch.channels_last) |
|
) |
|
block.load_state_dict(ref_block.state_dict()) |
|
correct = True |
|
for i in range(10): |
|
ref_x = torch.randn( |
|
cfg.batch_size, cfg.input_size**2, cfg.num_channels, device=device, requires_grad=cfg.backward |
|
) |
|
x = ref_x.clone().detach().to(dtype=dtype).requires_grad_(cfg.backward) |
|
with torch.autocast(device_type="cuda", dtype=autocast_dtype, enabled=cfg.autocast): |
|
output = block(x) |
|
ref_output_dict = simulate_litemla( |
|
ref_x, |
|
ref_block.qkv.weight, |
|
ref_block.proj.weight, |
|
ref_block.proj.bias, |
|
ref_block.in_dim // ref_block.dim, |
|
ref_block.dim, |
|
ref_block.eps, |
|
cfg.backward, |
|
) |
|
ref_output = ref_output_dict["y"] |
|
if cfg.backward: |
|
dy = 0.1 * torch.randn_like(output) |
|
output.backward(dy) |
|
ref_output.backward(dy.float()) |
|
|
|
ref_output_1 = ref_block(ref_x) |
|
assert torch.allclose(ref_output, ref_output_1) |
|
output_float = output.float() |
|
if not torch.allclose(output_float, ref_output): |
|
correct = False |
|
max_error_pos = (output_float - ref_output).abs().view(-1).argmax() |
|
print(f"comparing forward results") |
|
print( |
|
f"max error: {(output_float - ref_output).abs().max()}, mean error: {(output_float - ref_output).abs().mean()}" |
|
) |
|
print(f"max error pos: {ref_output.view(-1)[max_error_pos]} {output_float.view(-1)[max_error_pos]}") |
|
if cfg.backward: |
|
for name, grad, ref_grad in [ |
|
("proj_weight", block.proj.weight.grad, ref_block.proj.weight.grad), |
|
("proj_bias", block.proj.bias.grad, ref_block.proj.bias.grad), |
|
("qkv_weight", block.qkv.weight.grad, ref_block.qkv.weight.grad), |
|
("x", x.grad, ref_x.grad), |
|
]: |
|
print(f"comparing {name}") |
|
grad_float = grad.float() |
|
max_error_pos = (grad_float - ref_grad).abs().view(-1).argmax() |
|
print( |
|
f"max error: {(grad_float - ref_grad).abs().max()}, mean error: {(grad_float - ref_grad).abs().mean()}" |
|
) |
|
print(f"max error pos: {ref_grad.view(-1)[max_error_pos]} {grad_float.view(-1)[max_error_pos]}") |
|
|
|
if correct: |
|
print("correct!") |
|
elif cfg.use_cuda_graph: |
|
x = torch.randn( |
|
cfg.batch_size, |
|
cfg.input_size**2, |
|
cfg.num_channels, |
|
device=device, |
|
dtype=dtype, |
|
requires_grad=cfg.backward, |
|
) |
|
grad_y = 0.1 * torch.randn_like(x) |
|
|
|
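        # Standard CUDA Graphs recipe: warm up on a side stream (so allocator and
        # autotuning work is not captured), capture a single step, then time replays.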
s = torch.cuda.Stream() |
|
s.wait_stream(torch.cuda.current_stream()) |
|
with torch.cuda.stream(s): |
|
for i in range(cfg.warmup_iterations): |
|
with torch.autocast(device_type="cuda", dtype=autocast_dtype, enabled=cfg.autocast): |
|
y = block(x) |
|
if cfg.backward: |
|
y.backward(grad_y) |
|
torch.cuda.current_stream().wait_stream(s) |
|
|
|
g = torch.cuda.CUDAGraph() |
|
|
|
|
|
with torch.cuda.graph(g): |
|
with torch.autocast(device_type="cuda", dtype=autocast_dtype, enabled=cfg.autocast): |
|
y = block(x) |
|
if cfg.backward: |
|
y.backward(grad_y) |
|
|
|
torch.cuda.synchronize() |
|
start_time = time.time() |
|
for i in range(cfg.iterations): |
|
g.replay() |
|
torch.cuda.synchronize() |
|
end_time = time.time() |
|
print(f"using cuda graph:") |
|
print(f"each step takes {(end_time-start_time)*1000/cfg.iterations:.2f} ms") |
|
print(f"max memory allocated: {torch.cuda.max_memory_allocated()/1024**3:.4f} GB") |
|
else: |
|
x = torch.randn( |
|
cfg.batch_size, |
|
cfg.input_size**2, |
|
cfg.num_channels, |
|
device=device, |
|
dtype=dtype, |
|
requires_grad=cfg.backward, |
|
) |
|
grad_y = 0.1 * torch.randn_like(x) |
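        # Eager-mode timing: warm up for cfg.warmup_iterations, then time
        # cfg.iterations forward (and optionally backward) steps.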
|
for i in range(cfg.warmup_iterations): |
|
|
|
with torch.autocast(device_type="cuda", dtype=autocast_dtype, enabled=cfg.autocast): |
|
y = block(x) |
|
if cfg.backward: |
|
y.backward(grad_y) |
|
|
|
torch.cuda.synchronize() |
|
start_time = time.time() |
|
for i in range(cfg.iterations): |
|
with torch.autocast(device_type="cuda", dtype=autocast_dtype, enabled=cfg.autocast): |
|
y = block(x) |
|
if cfg.backward: |
|
y.backward(grad_y) |
|
torch.cuda.synchronize() |
|
end_time = time.time() |
|
print(f"each step takes {(end_time - start_time) * 1000 / cfg.iterations:.2f} ms") |
|
|
|
print(f"max memory allocated: {torch.cuda.max_memory_allocated() / 1024 ** 3:.4f} GB") |
|
|
if __name__ == "__main__": |
|
main() |
|
|
|
""" |
|
# 64x64 fp16 |
|
python -m develop_triton_litemla attn_type=LiteMLA test_correctness=True |
|
each step takes 10.81 ms |
|
max memory allocated: 2.2984 GB |
|
|
|
python -m develop_triton_litemla attn_type=TritonLiteMLA test_correctness=True |
|
each step takes 4.70 ms |
|
max memory allocated: 1.6480 GB |
|
""" |
|
|