# Copyright 2024 MIT Han Lab
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0

import time
from dataclasses import dataclass

import ipdb
import torch
from modules.flash_attn import FlashAttention
from modules.lite_mla import LiteMLA
from modules.triton_lite_mla import TritonLiteMLA
from modules.triton_lite_mla_fwd import TritonLiteMLAFwd
from modules.utils.dtype import get_dtype_from_str
from modules.utils.export_onnx import export_onnx
from omegaconf import OmegaConf
from torch import nn
from torch.nn import functional as F
from torchprofile import profile_macs


@dataclass
class DevelopTritonLiteMLAConfig:
    batch_size: int = 16
    input_size: int = 1024 // 8 // 2
    num_channels: int = 1152
    num_heads: int = 36
    attn_type: str = "LiteMLA"
    device: str = "cuda"
    dtype: str = "fp16"
    profile_macs: bool = False
    test_correctness: bool = False
    warmup_iterations: int = 50
    iterations: int = 1000
    random_weight: bool = True
    backward: bool = False
    autocast: bool = False
    use_cuda_graph: bool = False
    export_model: bool = False
    opset: int = 17
    export_path: str = ""
    export_dtype: str = "fp32"
    export_device: str = "cuda"


def simulate_litemla(
    x: torch.Tensor,
    qkv_weight: torch.Tensor,
    proj_weight: torch.Tensor,
    proj_bias: torch.Tensor,
    num_heads: int,
    head_dim: int,
    eps: float,
    backward: bool,
):
    B, N, C = x.shape
    qkv = F.linear(x, qkv_weight).reshape(B, N, 3, C).permute(0, 2, 3, 1)
    q, k, v = qkv.unbind(1)  # B, 3, C, N --> B, C, N
    q = q.reshape(B, C // head_dim, head_dim, N)  # b, h, h_d, N
    k = k.reshape(B, C // head_dim, head_dim, N).transpose(-1, -2)  # b, h, N, h_d
    v = v.reshape(B, C // head_dim, head_dim, N)  # b, h, h_d, N
    q = F.relu(q)  # B, h, h_d, N
    k = F.relu(k)
    q, k, v = q.float(), k.float(), v.float()
    if backward:
        k.retain_grad()
        v.retain_grad()
        q.retain_grad()
    v_pad = F.pad(v, (0, 0, 0, 1), mode="constant", value=1)
    vk = torch.matmul(v_pad, k)
    if backward:
        vk.retain_grad()
    vk_q = torch.matmul(vk, q)
    vk_q_numerator, vk_q_denominator = vk_q[:, :, :-1], vk_q[:, :, -1:]
    if backward:
        vk_q_numerator.retain_grad()
        vk_q_denominator.retain_grad()
    vk_q_divide = (vk_q_numerator / (vk_q_denominator + eps)).to(x.dtype)
    proj_input = vk_q_divide.view(B, C, N).permute(0, 2, 1)  # B, N, C
    if backward:
        proj_input.retain_grad()
    y = F.linear(proj_input, proj_weight, proj_bias)
    output_dict = {
        "q": q,
        "k": k,
        "v": v,
        "vk": vk,
        "proj_input": proj_input,
        "vk_q_numerator": vk_q_numerator,
        "vk_q_denominator": vk_q_denominator,
        "vk_q_divide": vk_q_divide,
        "y": y,
    }
    return output_dict
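# `simulate_litemla` above relies on the linear-attention trick of padding v with a
# row of ones so that a single (v_pad @ k) @ q matmul chain yields both the
# numerator and the normalizer of ReLU linear attention. The helper below is a
# minimal, self-contained sanity check of that identity on small random tensors;
# `_check_vpad_trick` is an illustrative sketch added for this write-up (it reuses
# the module-level torch/F imports) and is not called anywhere in the script.
def _check_vpad_trick(B: int = 2, h: int = 3, h_d: int = 4, N: int = 5, eps: float = 1e-8) -> bool:
    q = F.relu(torch.randn(B, h, h_d, N))  # b, h, h_d, N
    k = F.relu(torch.randn(B, h, h_d, N)).transpose(-1, -2)  # b, h, N, h_d
    v = torch.randn(B, h, h_d, N)  # b, h, h_d, N
    # Fused form: append a row of ones to v, then one matmul chain produces a
    # (h_d + 1)-row result whose last row is the normalizer.
    v_pad = F.pad(v, (0, 0, 0, 1), mode="constant", value=1)  # b, h, h_d + 1, N
    vk_q = torch.matmul(torch.matmul(v_pad, k), q)  # b, h, h_d + 1, N
    fused = vk_q[:, :, :-1] / (vk_q[:, :, -1:] + eps)
    # Unfused form: compute numerator and normalizer with separate matmuls.
    numerator = torch.matmul(torch.matmul(v, k), q)  # b, h, h_d, N
    denominator = torch.matmul(k.sum(dim=-2, keepdim=True), q)  # b, h, 1, N
    unfused = numerator / (denominator + eps)
    return torch.allclose(fused, unfused, rtol=1e-4, atol=1e-5)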
def main():
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    LiteMLA.fp32_attention = True
    torch.cuda.manual_seed(0)
    torch.manual_seed(0)

    # Build the config from the dataclass defaults and allow any field to be
    # overridden from the command line (e.g. attn_type=TritonLiteMLA).
    cfg = OmegaConf.structured(DevelopTritonLiteMLAConfig)
    cli_cfg = OmegaConf.from_cli()
    cfg = OmegaConf.merge(cfg, OmegaConf.masked_copy(cli_cfg, cfg.keys()))
    cfg: DevelopTritonLiteMLAConfig = OmegaConf.to_object(cfg)

    torch.set_grad_enabled(cfg.backward)
    device = torch.device("cuda")
    if cfg.autocast:
        dtype = torch.float32
        autocast_dtype = get_dtype_from_str(cfg.dtype)
    else:
        dtype = get_dtype_from_str(cfg.dtype)
        autocast_dtype = None

    if cfg.attn_type == "LiteMLA":
        block = LiteMLA(cfg.num_channels, cfg.num_channels, dim=cfg.num_channels // cfg.num_heads, eps=1e-8)
    elif cfg.attn_type == "TritonLiteMLA":
        block = TritonLiteMLA(cfg.num_channels, cfg.num_heads, eps=1e-8)
    elif cfg.attn_type == "TritonLiteMLAFwd":
        block = TritonLiteMLAFwd(cfg.num_channels, cfg.num_heads, eps=1e-8)
    elif cfg.attn_type == "FlashAttention":
        block = FlashAttention(cfg.num_channels, cfg.num_heads)
    else:
        raise NotImplementedError

    if not cfg.backward:
        block = block.eval()
    block = block.to(device=device, dtype=dtype, memory_format=torch.channels_last)
    if cfg.random_weight:
        for param in block.parameters():
            nn.init.trunc_normal_(param, std=0.001)

    if cfg.profile_macs:
        # Representative input for MAC counting.
        x = torch.randn(cfg.batch_size, cfg.input_size**2, cfg.num_channels, device=device, dtype=dtype)
        macs = profile_macs(block, x)
        print(f"macs: {macs}")

    if cfg.export_model:
        export_dtype = get_dtype_from_str(cfg.export_dtype)
        export_device = torch.device(cfg.export_device)
        assert cfg.export_path != ""
        export_onnx(
            block.to(device=export_device, dtype=export_dtype),
            (1, cfg.input_size**2, cfg.num_channels),
            cfg.export_path,
            cfg.opset,
            export_dtype,
            export_device,
        )

    if cfg.test_correctness:
        ref_block = (
            LiteMLA(cfg.num_channels, cfg.num_channels, dim=cfg.num_channels // cfg.num_heads, eps=1e-8)
            .eval()
            .to(device=device, memory_format=torch.channels_last)
        )
        block.load_state_dict(ref_block.state_dict())
        correct = True
        for i in range(10):
            ref_x = torch.randn(
                cfg.batch_size, cfg.input_size**2, cfg.num_channels, device=device, requires_grad=cfg.backward
            )
            x = ref_x.clone().detach().to(dtype=dtype).requires_grad_(cfg.backward)
            with torch.autocast(device_type="cuda", dtype=autocast_dtype, enabled=cfg.autocast):
                output = block(x)
            ref_output_dict = simulate_litemla(
                ref_x,
                ref_block.qkv.weight,
                ref_block.proj.weight,
                ref_block.proj.bias,
                ref_block.in_dim // ref_block.dim,
                ref_block.dim,
                ref_block.eps,
                cfg.backward,
            )
            ref_output = ref_output_dict["y"]
            if cfg.backward:
                dy = 0.1 * torch.randn_like(output)
                output.backward(dy)
                ref_output.backward(dy.float())
                # ipdb.set_trace()
            ref_output_1 = ref_block(ref_x)
            assert torch.allclose(ref_output, ref_output_1)

            output_float = output.float()
            if not torch.allclose(output_float, ref_output):
                correct = False
            max_error_pos = (output_float - ref_output).abs().view(-1).argmax()
            print("comparing forward results")
            print(
                f"max error: {(output_float - ref_output).abs().max()}, "
                f"mean error: {(output_float - ref_output).abs().mean()}"
            )
            print(f"max error pos: {ref_output.view(-1)[max_error_pos]} {output_float.view(-1)[max_error_pos]}")

            if cfg.backward:
                for name, grad, ref_grad in [
                    ("proj_weight", block.proj.weight.grad, ref_block.proj.weight.grad),
                    ("proj_bias", block.proj.bias.grad, ref_block.proj.bias.grad),
                    ("qkv_weight", block.qkv.weight.grad, ref_block.qkv.weight.grad),
                    ("x", x.grad, ref_x.grad),
                ]:
                    print(f"comparing {name}")
                    grad_float = grad.float()
                    max_error_pos = (grad_float - ref_grad).abs().view(-1).argmax()
                    print(
                        f"max error: {(grad_float - ref_grad).abs().max()}, "
                        f"mean error: {(grad_float - ref_grad).abs().mean()}"
                    )
                    print(f"max error pos: {ref_grad.view(-1)[max_error_pos]} {grad_float.view(-1)[max_error_pos]}")
                    # ipdb.set_trace()
        if correct:
            print("correct!")
    elif cfg.use_cuda_graph:
        x = torch.randn(
            cfg.batch_size,
            cfg.input_size**2,
            cfg.num_channels,
            device=device,
            dtype=dtype,
            requires_grad=cfg.backward,
        )
        grad_y = 0.1 * torch.randn_like(x)

        # Warm up on a side stream before graph capture.
        s = torch.cuda.Stream()
        s.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(s):
            for i in range(cfg.warmup_iterations):
                with torch.autocast(device_type="cuda", dtype=autocast_dtype, enabled=cfg.autocast):
                    y = block(x)
                if cfg.backward:
                    y.backward(grad_y)
        torch.cuda.current_stream().wait_stream(s)

        g = torch.cuda.CUDAGraph()
        # Sets grads to None before capture, so backward() will create
        # .grad attributes with allocations from the graph's private pool
        with torch.cuda.graph(g):
            with torch.autocast(device_type="cuda", dtype=autocast_dtype, enabled=cfg.autocast):
                y = block(x)
            if cfg.backward:
                y.backward(grad_y)

        torch.cuda.synchronize()
        start_time = time.time()
        for i in range(cfg.iterations):
            g.replay()
        torch.cuda.synchronize()
        end_time = time.time()
        print("using cuda graph:")
        print(f"each step takes {(end_time - start_time) * 1000 / cfg.iterations:.2f} ms")
        print(f"max memory allocated: {torch.cuda.max_memory_allocated() / 1024**3:.4f} GB")
    else:
        x = torch.randn(
            cfg.batch_size,
            cfg.input_size**2,
            cfg.num_channels,
            device=device,
            dtype=dtype,
            requires_grad=cfg.backward,
        )
        grad_y = 0.1 * torch.randn_like(x)
        for i in range(cfg.warmup_iterations):
            # ipdb.set_trace()
            with torch.autocast(device_type="cuda", dtype=autocast_dtype, enabled=cfg.autocast):
                y = block(x)
            if cfg.backward:
                y.backward(grad_y)
        torch.cuda.synchronize()
        start_time = time.time()
        for i in range(cfg.iterations):
            with torch.autocast(device_type="cuda", dtype=autocast_dtype, enabled=cfg.autocast):
                y = block(x)
            if cfg.backward:
                y.backward(grad_y)
        torch.cuda.synchronize()
        end_time = time.time()
        print(f"each step takes {(end_time - start_time) * 1000 / cfg.iterations:.2f} ms")
        # ipdb.set_trace()
        print(f"max memory allocated: {torch.cuda.max_memory_allocated() / 1024**3:.4f} GB")
        # x = torch.randn(cfg.batch_size*2, (cfg.input_size*2)**2, cfg.num_channels, device=device, dtype=dtype, requires_grad=cfg.backward)
        # grad_y = 0.1*torch.randn_like(x)
        # with torch.autocast(device_type="cuda", dtype=autocast_dtype, enabled=cfg.autocast):
        #     y = block(x)
        #     if cfg.backward:
        #         y.backward(grad_y)
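# The timing loops in `main` bracket `cfg.iterations` steps between two
# `torch.cuda.synchronize()` calls and a host-side `time.time()`. As a point of
# comparison only, the sketch below measures the same thing with CUDA events,
# which record timestamps on the device; `_time_with_cuda_events` is an
# illustrative helper added for this write-up and is not called by the script.
def _time_with_cuda_events(step_fn, iterations: int = 100) -> float:
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()
    for _ in range(iterations):
        step_fn()  # e.g. lambda: block(x)
    end_event.record()
    torch.cuda.synchronize()  # make sure both events have completed
    return start_event.elapsed_time(end_event) / iterations  # milliseconds per step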
if __name__ == "__main__":
    main()


"""
# 64x64 fp16
python -m develop_triton_litemla attn_type=LiteMLA test_correctness=True
each step takes 10.81 ms
max memory allocated: 2.2984 GB

python -m develop_triton_litemla attn_type=TritonLiteMLA test_correctness=True
each step takes 4.70 ms
max memory allocated: 1.6480 GB
"""
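# Any other DevelopTritonLiteMLAConfig field can be overridden from the command
# line in the same way; the combinations below are illustrative only (no timing
# claims), e.g. benchmarking the backward pass or replaying the captured step
# from a CUDA graph:
#   python -m develop_triton_litemla attn_type=TritonLiteMLA backward=True
#   python -m develop_triton_litemla attn_type=TritonLiteMLA use_cuda_graph=True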