Commit 8b64fa8
Parent(s): f2e0e62

chore: remove parallelmha

Signed-off-by: jupyterjazz <[email protected]>

mha.py CHANGED
@@ -7,8 +7,6 @@ import torch
 import torch.nn as nn
 from einops import rearrange, repeat

-from flash_attn.utils.distributed import get_dim_for_local_rank
-
 try:
     from flash_attn import (
         flash_attn_kvpacked_func,
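The only change in this first hunk is dropping the import of get_dim_for_local_rank, which the removed ParallelMHA class (second hunk below) used to work out how many query and KV heads each tensor-parallel rank owns. As a rough illustration of that kind of per-rank split, here is a minimal sketch; the function name and behavior below are assumptions for illustration, not the exact flash_attn.utils.distributed implementation, which may take extra arguments (e.g. a rounding multiple).

# Hypothetical sketch of a per-rank head split (assumed semantics, not
# flash_attn's exact get_dim_for_local_rank): divide `dim` slots across
# `world_size` ranks as evenly as possible, lower ranks absorbing the remainder.
def dim_for_local_rank(dim: int, world_size: int, local_rank: int) -> int:
    base, rem = divmod(dim, world_size)
    return base + (1 if local_rank < rem else 0)

# Example: 10 KV heads across 4 ranks -> [3, 3, 2, 2]
assert [dim_for_local_rank(10, 4, r) for r in range(4)] == [3, 3, 2, 2]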
@@ -706,316 +704,3 @@ class MHA(nn.Module):
                 context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
         out = self.out_proj(rearrange(context, "... h d -> ... (h d)"))
         return out if not self.return_residual else (out, x)
-
-
-class ParallelMHA(nn.Module):
-    """Multi-head self-attention and cross-attention"""
-
-    def __init__(
-        self,
-        embed_dim,
-        num_heads,
-        process_group,
-        num_heads_kv=None,
-        qkv_proj_bias=True,
-        out_proj_bias=True,
-        dropout=0.0,
-        softmax_scale=None,
-        causal=False,
-        layer_idx=None,
-        rotary_emb_dim=0,
-        rotary_emb_base=10000.0,
-        rotary_emb_scale_base=None,
-        rotary_emb_interleaved=False,
-        use_alibi=False,
-        use_flash_attn=False,
-        checkpointing=False,
-        sequence_parallel=True,
-        device=None,
-        dtype=None,
-    ) -> None:
-        factory_kwargs = {"device": device, "dtype": dtype}
-        super().__init__()
-        self.embed_dim = embed_dim
-        self.causal = causal
-        self.layer_idx = layer_idx
-        self.rotary_emb_dim = rotary_emb_dim
-        self.use_flash_attn = use_flash_attn
-        self.checkpointing = checkpointing
-        self.process_group = process_group
-        self.world_size = process_group.size()
-        self.local_rank = torch.distributed.get_rank(process_group)
-
-        self.num_heads = num_heads
-        assert self.embed_dim % self.num_heads == 0, "embed_dim must be divisible by num_heads"
-
-        self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
-        assert (
-            self.num_heads % self.num_heads_kv == 0
-        ), "num_heads must be divisible by num_heads_kv"
-
-        self.num_heads_per_rank = get_dim_for_local_rank(
-            self.num_heads, self.world_size, self.local_rank
-        )
-        self.num_heads_kv_per_rank = get_dim_for_local_rank(
-            self.num_heads_kv, self.world_size, self.local_rank
-        )
-        self.head_dim = self.embed_dim // num_heads
-        qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
-
-        if use_alibi:
-            assert use_flash_attn, "ALiBi code path requires flash_attn"
-            num_heads_local = math.ceil(self.num_heads / self.world_size)
-            alibi_slopes = torch.tensor(
-                get_alibi_slopes(num_heads)[
-                    self.local_rank * num_heads_local : (self.local_rank + 1) * num_heads_local
-                ],
-                device=device,
-            )
-        else:
-            alibi_slopes = None
-
-        if self.rotary_emb_dim > 0:
-            assert RotaryEmbedding is not None, "rotary_emb is not installed"
-            self.rotary_emb = RotaryEmbedding(
-                self.rotary_emb_dim,
-                base=rotary_emb_base,
-                scale_base=rotary_emb_scale_base,
-                interleaved=rotary_emb_interleaved,
-                device=device,
-            )
-
-        if ColumnParallelLinear is None or RowParallelLinear is None:
-            raise ImportError("fused_dense is not installed")
-        self.Wqkv = ColumnParallelLinear(
-            embed_dim,
-            qkv_dim,
-            process_group,
-            bias=qkv_proj_bias,
-            sequence_parallel=sequence_parallel,
-            multiple_of=self.head_dim * (self.num_heads // self.num_heads_kv + 2),
-            **factory_kwargs,
-        )
-        inner_attn_cls = (
-            partial(FlashSelfAttention, alibi_slopes=alibi_slopes)
-            if use_flash_attn
-            else SelfAttention
-        )
-        inner_cross_attn_cls = (
-            partial(FlashCrossAttention, alibi_slopes=alibi_slopes)
-            if use_flash_attn
-            else CrossAttention
-        )
-        self.inner_attn = inner_attn_cls(
-            causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
-        )
-        self.inner_cross_attn = inner_cross_attn_cls(
-            causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
-        )
-        self.out_proj = RowParallelLinear(
-            embed_dim,
-            embed_dim,
-            process_group,
-            bias=out_proj_bias,
-            sequence_parallel=sequence_parallel,
-            multiple_of=self.head_dim,
-            **factory_kwargs,
-        )
-
-    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
-        dtype = self.out_proj.weight.dtype if dtype is None else dtype
-        device = self.out_proj.weight.device
-        return torch.empty(
-            batch_size,
-            max_seqlen,
-            2,
-            self.num_heads_kv_per_rank,
-            self.head_dim,
-            dtype=dtype,
-            device=device,
-        )
-
-    def _update_kv_cache(self, kv, inference_params):
-        """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
-        assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
-        return _update_kv_cache(kv, inference_params, self.layer_idx)
-
-    def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params):
-        """
-        Fast path that combine 3 steps: apply rotary to Q and K, update kv cache, and apply attention.
-        q: (batch_size, seqlen_q, nheads, head_dim)
-        kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim)
-        """
-        assert inference_params is not None and inference_params.seqlen_offset > 0
-        assert self.use_flash_attn
-        if self.rotary_emb_dim > 0:
-            assert self.rotary_emb.scale is None, "This code path does not support xPos"
-            self.rotary_emb._update_cos_sin_cache(
-                inference_params.max_seqlen, device=q.device, dtype=q.dtype
-            )
-            rotary_cos, rotary_sin = self.rotary_emb._cos_cached, self.rotary_emb._sin_cached
-        else:
-            rotary_cos, rotary_sin = None, None
-        batch = q.shape[0]
-        kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
-        cache_seqlens = (
-            inference_params.lengths_per_sample[:batch]
-            if inference_params.lengths_per_sample is not None
-            else inference_params.seqlen_offset
-        )
-        alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None)
-        context = flash_attn_with_kvcache(
-            q,
-            kv_cache[:, :, 0],
-            kv_cache[:, :, 1],
-            kv[:, :, 0],
-            kv[:, :, 1],
-            rotary_cos=rotary_cos,
-            rotary_sin=rotary_sin,
-            cache_seqlens=cache_seqlens,
-            softmax_scale=self.inner_cross_attn.softmax_scale,
-            causal=self.inner_cross_attn.causal,
-            rotary_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False,
-            alibi_slopes=alibi_slopes,
-        )
-        return context
-
-    def _update_kvcache_attention(self, q, kv, inference_params):
-        """Write kv to inference_params, then do attention"""
-        if inference_params.seqlen_offset == 0 or not self.use_flash_attn:
-            # TODO: this only uses seqlen_offset and not lengths_per_sample.
-            kv = self._update_kv_cache(kv, inference_params)
-            return self.inner_cross_attn(q, kv)
-        else:
-            batch = q.shape[0]
-            kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
-            cache_seqlens = (
-                inference_params.lengths_per_sample[:batch]
-                if inference_params.lengths_per_sample is not None
-                else inference_params.seqlen_offset
-            )
-            alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None)
-            context = flash_attn_with_kvcache(
-                q,
-                kv_cache[:, :, 0],
-                kv_cache[:, :, 1],
-                kv[:, :, 0],
-                kv[:, :, 1],
-                cache_seqlens=cache_seqlens,
-                softmax_scale=self.inner_cross_attn.softmax_scale,
-                causal=self.inner_cross_attn.causal,
-                alibi_slopes=alibi_slopes,
-            )
-            return context
-
-    def forward(
-        self, x, seqlen=None, inference_params=None, cu_seqlens=None, max_seqlen=None, **kwargs
-    ):
-        """
-        Arguments:
-            x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if seqlen=None and cu_seqlens=None.
-                (seqlen, hidden_dim) if cu_seqlens not None, seqlen equal cu_seqlens[-1].
-                If seqlen is not None and cu_seqlens=None, x is (batch * seqlen, hidden_dim). This is so that when we
-                split x during sequence parallel, we split the batch * seqlen dimension
-                (in case batch is small).
-            cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
-                of the sequences in the batch, used to index into x. Only applicable when using
-                FlashAttention.
-            max_seqlen: int. Maximum sequence length in the batch.
-        """
-        if cu_seqlens is not None:
-            assert max_seqlen is not None
-            assert seqlen is None
-            assert self.use_flash_attn
-        if inference_params is not None:
-            assert cu_seqlens is None and max_seqlen is None
-        qkv = self.Wqkv(x)
-        if seqlen is not None:
-            qkv = rearrange(qkv, "(b s) ... -> b s ...", s=seqlen)
-        kwargs = (
-            {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen, **kwargs}
-            if self.use_flash_attn
-            else kwargs
-        )
-        seqlen_offset = (
-            0
-            if inference_params is None
-            else (
-                inference_params.lengths_per_sample
-                if inference_params.lengths_per_sample is not None
-                else inference_params.seqlen_offset
-            )
-        )
-        rotary_max_seqlen = (
-            inference_params.max_sequence_len if inference_params is not None else max_seqlen
-        )
-        if self.num_heads_kv == self.num_heads:
-            qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
-            if (
-                inference_params is None
-                or inference_params.seqlen_offset == 0
-                or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
-                or not self.use_flash_attn
-            ):
-                if self.rotary_emb_dim > 0:
-                    qkv = self.rotary_emb(
-                        qkv,
-                        seqlen_offset=seqlen_offset,
-                        cu_seqlens=cu_seqlens,
-                        max_seqlen=rotary_max_seqlen,
-                    )
-                if inference_params is None:
-                    if not self.checkpointing:
-                        context = self.inner_attn(qkv, **kwargs)
-                    else:
-                        context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs)
-                else:
-                    context = self._update_kvcache_attention(
-                        qkv[:, :, 0], qkv[:, :, 1:], inference_params
-                    )
-            else:
-                context = self._apply_rotary_update_kvcache_attention(
-                    qkv[:, :, 0], qkv[:, :, 1:], inference_params
-                )
-        else:
-            q = rearrange(
-                qkv[..., : self.num_heads_per_rank * self.head_dim],
-                "... (h d) -> ... h d",
-                d=self.head_dim,
-            )
-            kv = rearrange(
-                qkv[..., self.num_heads_per_rank * self.head_dim :],
-                "... (two hkv d) -> ... two hkv d",
-                two=2,
-                d=self.head_dim,
-            )
-            if (
-                inference_params is None
-                or inference_params.seqlen_offset == 0
-                or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
-                or not self.use_flash_attn
-            ):
-                if self.rotary_emb_dim > 0:
-                    q, kv = self.rotary_emb(
-                        q,
-                        kv,
-                        seqlen_offset=seqlen_offset,
-                        cu_seqlens=cu_seqlens,
-                        max_seqlen=rotary_max_seqlen,
-                    )
-                if inference_params is None:
-                    if not self.checkpointing:
-                        context = self.inner_cross_attn(q, kv, **kwargs)
-                    else:
-                        context = torch.utils.checkpoint.checkpoint(
-                            self.inner_cross_attn, q, kv, **kwargs
-                        )
-                else:
-                    context = self._update_kvcache_attention(q, kv, inference_params)
-            else:
-                context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
-        context = rearrange(context, "... h d -> ... (h d)")
-        if seqlen is not None:
-            context = rearrange(context, "b s d -> (b s) d")
-        out = self.out_proj(context)
-        return out
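For reference, here is a hedged sketch of how the removed class was constructed, reconstructed only from the deleted __init__ signature above. All concrete values (embed_dim=4096, 32 query heads, 8 KV heads, rotary dim 128) are hypothetical, and the call is commented out because ParallelMHA no longer exists in this file after the commit; callers of this fork would fall back to the remaining MHA class or to an upstream flash_attn build that still ships a ParallelMHA.

# Hypothetical usage of the removed ParallelMHA, based on the deleted signature
# above. Kept commented out: the class is gone from this file after the commit.
import torch
import torch.distributed as dist

# Assumes torch.distributed is already initialized (e.g. via torchrun).
# attn = ParallelMHA(
#     embed_dim=4096,                  # hypothetical model width
#     num_heads=32,                    # hypothetical number of query heads
#     process_group=dist.group.WORLD,  # tensor-parallel process group
#     num_heads_kv=8,                  # grouped-query attention
#     rotary_emb_dim=128,              # hypothetical rotary dimension
#     use_flash_attn=True,
#     sequence_parallel=True,
#     device="cuda",
#     dtype=torch.bfloat16,
# )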