import torch
from triton_flash_atn import _attention

# Define dimensions
batch_size = 2
num_heads = 4
seq_len = 128
head_dim = 64

# Create random half-precision (float16) input tensors for Q, K, V on the GPU
q = torch.randn(batch_size, num_heads, seq_len, head_dim,
                dtype=torch.float16, device='cuda')
k = torch.randn(batch_size, num_heads, seq_len, head_dim,
                dtype=torch.float16, device='cuda')
v = torch.randn(batch_size, num_heads, seq_len, head_dim,
                dtype=torch.float16, device='cuda')

# Use non-causal attention with the standard softmax scaling factor 1/sqrt(head_dim)
causal = False
sm_scale = 1.0 / (head_dim ** 0.5)

# Apply flash attention. _attention is a torch.autograd.Function,
# so it is invoked through its .apply method.
# The output has the same shape as q, k, and v.
attention = _attention.apply
output = attention(q, k, v, causal, sm_scale)

print(output)
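
# Optional sanity check (an assumption in this sketch, not part of the kernel's API):
# compare against PyTorch's built-in scaled_dot_product_attention (PyTorch >= 2.0).
# Its default scale is 1/sqrt(head_dim), which matches sm_scale above.
import torch.nn.functional as F

expected = F.scaled_dot_product_attention(q, k, v, is_causal=causal)
print("matches PyTorch SDPA:",
      torch.allclose(output, expected, atol=1e-2, rtol=0.0))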