# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch
import torch.distributed as dist

from ..modules.attention import flash_attention
from .util import all_to_all
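

# A minimal reference sketch of what the all_to_all helper imported above is
# assumed to do (the real implementation lives in .util and is not shown
# here): split the input into world_size chunks along scatter_dim, exchange
# one chunk with every rank, then concatenate the received chunks along
# gather_dim. The name _all_to_all_sketch and the even-divisibility
# assumption are ours, not the library's.
def _all_to_all_sketch(x, scatter_dim, gather_dim, group=None):
    world_size = dist.get_world_size(group)
    # one contiguous chunk per rank along the scatter dimension
    # (assumes x.size(scatter_dim) is divisible by world_size)
    inputs = [t.contiguous() for t in x.chunk(world_size, dim=scatter_dim)]
    outputs = [torch.empty_like(t) for t in inputs]
    # after the exchange, outputs[i] holds the chunk sent by rank i
    dist.all_to_all(outputs, inputs, group=group)
    return torch.cat(outputs, dim=gather_dim)
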
def distributed_attention(
    q,
    k,
    v,
    seq_lens,
    window_size=(-1, -1),
):
| """ | |
| Performs distributed attention based on DeepSpeed Ulysses attention mechanism. | |
| please refer to https://arxiv.org/pdf/2309.14509 | |
| Args: | |
| q: [B, Lq // p, Nq, C1]. | |
| k: [B, Lk // p, Nk, C1]. | |
| v: [B, Lk // p, Nk, C2]. Nq must be divisible by Nk. | |
| seq_lens: [B], length of each sequence in batch | |
| window_size: (left right). If not (-1, -1), apply sliding window local attention. | |
| """ | |
    if not dist.is_initialized():
        raise ValueError("distributed group should be initialized.")
    b = q.shape[0]

    # all_to_all: gather the full sequence on every rank while scattering the
    # attention heads, i.e. [B, L // p, N, C] -> [B, L, N // p, C]
    q = all_to_all(q, scatter_dim=2, gather_dim=1)
    k = all_to_all(k, scatter_dim=2, gather_dim=1)
    v = all_to_all(v, scatter_dim=2, gather_dim=1)
    # apply attention over the full sequence with this rank's subset of heads
    x = flash_attention(
        q,
        k,
        v,
        k_lens=seq_lens,
        window_size=window_size,
    )
    # all_to_all back: scatter the output sequence and gather the heads,
    # i.e. [B, L, Nq // p, C2] -> [B, L // p, Nq, C2]
    x = all_to_all(x, scatter_dim=1, gather_dim=2)
    return x
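

# Hypothetical usage under torchrun (e.g. `torchrun --nproc_per_node=2 ...`);
# the shapes, dtype, and process-group setup below are illustrative and not
# taken from the Wan codebase.
def _demo():
    dist.init_process_group("nccl")
    torch.cuda.set_device(dist.get_rank())
    p = dist.get_world_size()
    b, l, n, c = 1, 1024, 16, 64  # full sequence length l; each rank holds l // p
    q = torch.randn(b, l // p, n, c, device="cuda", dtype=torch.bfloat16)
    k = torch.randn(b, l // p, n, c, device="cuda", dtype=torch.bfloat16)
    v = torch.randn(b, l // p, n, c, device="cuda", dtype=torch.bfloat16)
    seq_lens = torch.tensor([l], device="cuda")
    x = distributed_attention(q, k, v, seq_lens)  # [b, l // p, n, c] on each rank
    dist.destroy_process_group()
    return x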