# File size: 7,827 bytes (revision dd7417a)
from typing import Any, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from flash_attn import flash_attn_varlen_func
# deepspeed.comm is optional: without it ``dist`` is None and only the
# non-distributed code path (LocalAttention) can be used.
try:
    import deepspeed.comm as dist
except ImportError:
    dist = None

try:
    from utils import (
        get_sequence_parallel_group,
        get_sequence_parallel_size,
        get_sequence_parallel_rank,
    )
except (ModuleNotFoundError, ImportError):
    # Get the sequence-parallel settings from utils; if the import fails,
    # sequence parallelism defaults to disabled.
    get_sequence_parallel_group = lambda: None
    get_sequence_parallel_size = lambda: 1
    get_sequence_parallel_rank = lambda: 0
def single_all_to_all(input, scatter_idx, gather_idx, group):
    """Redistribute ``input`` across ``group`` with a single all-to-all.

    Each rank splits ``input`` into ``world_size`` equal chunks along
    ``scatter_idx`` and sends chunk ``i`` to rank ``i``. The received chunks
    are merged so that dimension ``gather_idx`` grows by ``world_size``.

    The reshapes/transposes are written for the two configurations that
    ``DistributedAttention`` uses on ``[seq, batch, heads, head_dim]`` tensors:

    * ``scatter_idx=2, gather_idx=0``: ``[s/p, b, h, d] -> [s, b, h/p, d]``
    * ``scatter_idx=0, gather_idx=2``: ``[s, b, h/p, d] -> [s/p, b, h, d]``

    Other index combinations are not guaranteed to produce correctly
    ordered output.

    Args:
        input: tensor to redistribute.
        scatter_idx: dimension that is split across ranks.
        gather_idx: dimension the received chunks are merged into.
        group: process group handed to ``dist``.

    Returns:
        A contiguous tensor whose ``scatter_idx`` size is divided by, and whose
        ``gather_idx`` size is multiplied by, the group's world size.

    Raises:
        ValueError: if ``input.shape[scatter_idx]`` is not divisible by the
            group's world size.
    """
    seq_world_size = dist.get_world_size(group)
    inp_shape = list(input.shape)
    # Without this check the ``-1`` reshape below can silently succeed on a
    # non-divisible dimension and scatter mis-ordered data.
    if inp_shape[scatter_idx] % seq_world_size != 0:
        raise ValueError(
            f"input.shape[{scatter_idx}]={inp_shape[scatter_idx]} is not "
            f"divisible by the sequence-parallel world size {seq_world_size}"
        )
    inp_shape[scatter_idx] = inp_shape[scatter_idx] // seq_world_size
    if scatter_idx < 2:
        # Split the leading dimension: [scatter, ...] -> [world, scatter/world, ...]
        input_t = input.reshape(
            [seq_world_size, inp_shape[scatter_idx]] + inp_shape[scatter_idx + 1:]
        ).contiguous()
    else:
        # transpose groups of heads with the seq-len parallel dimension, so that we can scatter them!
        input_t = input.reshape(
            [-1, seq_world_size, inp_shape[scatter_idx]] + inp_shape[scatter_idx + 1:]
        ).transpose(0, 1).contiguous()

    output = torch.empty_like(input_t)
    dist.all_to_all_single(output, input_t, group=group)

    # if scattering the seq-dim, transpose the heads back to the original dimension
    # [sp_size, seq_len//sp_size, batch_size, head_num // sp_size, head_dim] -->
    # [seq_len//sp_size, batch_size, sp_size, head_num // sp_size, head_dim]
    if scatter_idx < 2:
        output = output.transpose(0, 1).transpose(1, 2).contiguous()

    return output.reshape(
        inp_shape[:gather_idx]
        + [inp_shape[gather_idx] * seq_world_size]
        + inp_shape[gather_idx + 1:]
    ).contiguous()
class _SeqAllToAll(torch.autograd.Function):
    """Autograd wrapper around ``single_all_to_all``.

    The gradient of an all-to-all is another all-to-all over the same group
    with the scatter and gather dimensions swapped.
    """

    @staticmethod
    def forward(ctx: Any, group: 'dist.ProcessGroup', input: Tensor, scatter_idx: int, gather_idx: int) -> Tensor:
        # Saved for backward, which needs the same group and swapped indices.
        ctx.group = group
        ctx.scatter_idx = scatter_idx
        ctx.gather_idx = gather_idx
        return single_all_to_all(input, scatter_idx, gather_idx, group)

    @staticmethod
    def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor, None, None]:
        # One gradient slot per forward input: (group, input, scatter_idx, gather_idx);
        # only ``input`` is differentiable.
        return (
            None,
            _SeqAllToAll.apply(ctx.group, *grad_output, ctx.gather_idx, ctx.scatter_idx),
            None,
            None,
        )
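# Imported from upstream (apparently DeepSpeed's sequence-parallel
# DistributedAttention; the source link is missing here), with some bug fixes
# so that the tensor dimensions match the layout used in training.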
class DistributedAttention(nn.Module):
    """Sequence-parallel attention built from two all-to-all exchanges.

    Inputs arrive sharded along the sequence dimension. A first all-to-all
    regroups them so each rank holds the full sequence for a subset of heads.
    ``local_attention`` runs on that subset, and a second all-to-all restores
    the sequence sharding with all heads.

    Arguments:
        local_attention (Module): local attention with q, k, v
        sequence_process_group (ProcessGroup): sequence parallel process group
        scatter_idx (int): scatter_idx for all2all comm
        gather_idx (int): gather_idx for all2all comm

    NOTE(review): ``forward`` hard-codes transposes that assume the default
    ``scatter_idx=2`` / ``gather_idx=0``; other values are untested here.
    """

    def __init__(
        self,
        local_attention: nn.Module,
        sequence_process_group: 'dist.ProcessGroup',
        scatter_idx: int = 2,
        gather_idx: int = 0,
    ) -> None:
        super().__init__()
        self.local_attn = local_attention
        self.spg = sequence_process_group
        self.scatter_idx = scatter_idx
        self.gather_idx = gather_idx

    def pad_attention_head(self, query: Tensor, key: Tensor, value: Tensor):
        """Pad the head dimension (dim 1) up to a multiple of the group size.

        Padded query/key heads are filled with 0.01 and padded value heads
        with 0.0. The choice of 0.01 is not explained in the source. The
        padded heads are sliced off again at the end of ``forward``.

        Assumes 4-D ``[bs, num_head, seq_len, head_dim]`` inputs, because the
        pad tuple targets the third-from-last dimension.
        """
        sp_size = torch.distributed.get_world_size(self.spg)
        pad_size = (sp_size - query.size(1) % sp_size) % sp_size
        if pad_size > 0:
            # [bs, num_head, seq_len, head_dim] -> [bs, num_head+pad_size, seq_len, head_dim]
            head_pad = (0, 0, 0, 0, 0, pad_size)
            query = F.pad(query, head_pad, value=0.01)
            key = F.pad(key, head_pad, value=0.01)
            value = F.pad(value, head_pad, value=0.0)
        return query, key, value

    def forward(self, query: Tensor, key: Tensor, value: Tensor, *args: Any, **kwargs) -> Tensor:
        """Run sequence-parallel attention.

        Arguments:
            query (Tensor): query input to the layer, [batch_size, num_head, seq_len/p, head_dim]
                where seq_len/p is this rank's sequence shard
            key (Tensor): key input to the layer, same layout as query
            value (Tensor): value input to the layer, same layout as query
            args, kwargs: forwarded unchanged to ``local_attention``

        Returns:
            output (Tensor): context output [batch_size, seq_len/p, num_head, head_dim],
            assuming ``local_attention`` returns [batch, seq, heads, head_dim] as
            ``LocalAttention`` does.
        """
        # TODO Merge three alltoall calls into one
        # TODO (Reza): change the api on the megatron-deepspeed side so that we only receive all data (q,k, and v) together!
        origin_num_head = query.size(1)
        query, key, value = self.pad_attention_head(query, key, value)

        # [batch_size, num_head, seq_len, head_dim] -> [seq_len, batch_size, num_head, head_dim]
        query = query.transpose(1, 2).transpose(0, 1)
        key = key.transpose(1, 2).transpose(0, 1)
        value = value.transpose(1, 2).transpose(0, 1)

        # in: [s/p, bs, h, head_dim] -> all2all: [s, bs, h/p, head_dim]
        # -> local attention layout [bs, h/p, s, head_dim]
        query_layer = _SeqAllToAll.apply(self.spg, query, self.scatter_idx, self.gather_idx).transpose(0, 1).transpose(1, 2).contiguous()
        key_layer = _SeqAllToAll.apply(self.spg, key, self.scatter_idx, self.gather_idx).transpose(0, 1).transpose(1, 2).contiguous()
        value_layer = _SeqAllToAll.apply(self.spg, value, self.scatter_idx, self.gather_idx).transpose(0, 1).transpose(1, 2).contiguous()

        context_layer = self.local_attn(query_layer, key_layer, value_layer, *args, **kwargs)
        # [bs, s, h/p, head_dim] -> [s, bs, h/p, head_dim]
        context_layer = context_layer.transpose(0, 1).contiguous()

        # Reverse exchange: [s, bs, h/p, head_dim] -> [s/p, bs, h, head_dim]
        output = _SeqAllToAll.apply(self.spg, context_layer, self.gather_idx, self.scatter_idx)
        # -> [bs, s/p, h, head_dim], dropping the padded heads
        return output.transpose(0, 1)[:, :, :origin_num_head, :]
class LocalAttention(nn.Module):
    """Single-device attention.

    Input q/k/v: [batch_size, num_head, seq_len, head_dim].
    Output:      [batch_size, seq_len, num_head, head_dim].
    """

    def __init__(self, hidden_size, num_heads, head_dim):
        # nn.Module.__init__ creates the parameter/module/hook registries;
        # without it calling the module or nesting it in another one fails.
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = head_dim

    @staticmethod
    def _math_sdp_context():
        """Return a context manager that restricts SDPA to the math backend."""
        try:
            from torch.nn.attention import SDPBackend, sdpa_kernel
        except ImportError:
            # Older torch: fall back to the (since-deprecated) flag-based API.
            return torch.backends.cuda.sdp_kernel(
                enable_flash=False, enable_math=True, enable_mem_efficient=False
            )
        return sdpa_kernel(SDPBackend.MATH)

    def forward(self, q, k, v, *args, use_flash=True, **kwargs):
        """Compute attention.

        Args:
            q, k, v: [batch_size, num_head, seq_len, head_dim] tensors.
            *args, **kwargs: forwarded to ``flash_attn_varlen_func`` when
                ``use_flash`` (presumably cu_seqlens_q/k and max_seqlen_q/k —
                confirm with callers), otherwise to
                ``F.scaled_dot_product_attention``.
            use_flash: use flash-attn's varlen kernel instead of SDPA (math backend).

        Returns:
            [batch_size, seq_len, num_head, head_dim] tensor.
        """
        if use_flash:
            q_len, num_heads = q.shape[2], q.shape[1]
            # Pack into [batch*seq, heads, head_dim]. k/v use their own head
            # count so inputs with fewer kv heads than query heads pack correctly.
            q = q.transpose(1, 2).reshape(-1, num_heads, self.head_dim)
            k = k.transpose(1, 2).reshape(-1, k.shape[1], self.head_dim)
            v = v.transpose(1, 2).reshape(-1, v.shape[1], self.head_dim)
            out = flash_attn_varlen_func(q, k, v, *args, **kwargs)
            return out.reshape(-1, q_len, num_heads, self.head_dim)
        with self._math_sdp_context():
            attn_output = F.scaled_dot_product_attention(q, k, v, *args, **kwargs)
        # [bs, heads, seq, head_dim] -> [bs, seq, heads, head_dim]
        return attn_output.transpose(1, 2)
def create_attention_layer(hidden_size, num_heads, head_dim):
    """Build the attention module for the current parallel setup.

    Returns a plain ``LocalAttention`` when no sequence-parallel group is
    configured. Otherwise it returns a ``DistributedAttention`` that wraps a
    ``LocalAttention`` and runs it over that group.
    """
    local_attention = LocalAttention(hidden_size, num_heads, head_dim)
    group = get_sequence_parallel_group()
    if group is None:
        return local_attention
    return DistributedAttention(
        local_attention=local_attention,
        sequence_process_group=group,
    )
def get_sequence_parallel_chunk(tensor, dim=1, shift=0):
    """Return this rank's slice of ``tensor`` along ``dim``.

    Args:
        tensor: tensor whose size along ``dim`` must be divisible by the
            sequence-parallel size (checked with an assert).
        dim: dimension to slice.
        shift: number of leading elements dropped along ``dim`` before
            slicing. The chunk size is still derived from the unshifted
            length, so chunks keep their original boundaries and, when
            ``shift`` is smaller than a chunk, the last rank's chunk is
            ``shift`` elements shorter.

    Returns:
        The (possibly shifted) tensor itself when sequence parallelism is
        disabled, otherwise chunk number ``get_sequence_parallel_rank()``.
    """
    full_len = tensor.size(dim)
    assert full_len % get_sequence_parallel_size() == 0

    if shift:
        _, tensor = tensor.split([shift, full_len - shift], dim=dim)

    if get_sequence_parallel_group() is None:
        return tensor

    pieces = tensor.split(full_len // get_sequence_parallel_size(), dim=dim)
    return pieces[get_sequence_parallel_rank()]