fix bert_padding

Import the padding helpers directly from the local .bert_padding module (pad_input, unpad_input, unpad_input_only, index_first_axis) and call them without a module prefix in attention.py and layers.py; the compiled __pycache__ files change as a side effect.

Files changed:
- __pycache__/attention.cpython-311.pyc +0 -0
- __pycache__/layers.cpython-311.pyc +0 -0
- attention.py +13 -13
- layers.py +5 -5
__pycache__/attention.cpython-311.pyc
CHANGED
Binary files a/__pycache__/attention.cpython-311.pyc and b/__pycache__/attention.cpython-311.pyc differ

__pycache__/layers.cpython-311.pyc
CHANGED
Binary files a/__pycache__/layers.cpython-311.pyc and b/__pycache__/layers.cpython-311.pyc differ
attention.py
CHANGED

@@ -24,7 +24,7 @@ import sys
 import os
 # Add src folder root to path to allow us to use relative imports regardless of what directory the script is run from
 sys.path.append(os.path.dirname(os.path.realpath(__file__)))
-import
+from .bert_padding import pad_input, unpad_input_only, index_first_axis
 from .configuration_bert import FlexBertConfig, maybe_add_padding
 from .normalization import get_norm_layer
 from .initialization import ModuleType, init_weights

@@ -161,7 +161,7 @@ class BertAlibiUnpadSelfAttention(nn.Module):
 alibi_slopes=slopes,
 )
 else:
-qkv =
+qkv = pad_input(qkv, indices, cu_seqlens.shape[0] - 1, max_seqlen) # batch, max_seqlen, thd
 unpad_bs, *_ = qkv.shape
 qkv = qkv.view(unpad_bs, -1, 3, self.num_attention_heads, self.attention_head_size)
 # if we have nonzero attention dropout (e.g. during fine-tuning) or no Triton, compute attention in PyTorch

@@ -174,7 +174,7 @@ class BertAlibiUnpadSelfAttention(nn.Module):
 attention_probs = self.dropout(attention_probs)
 attention = torch.matmul(attention_probs, v).permute(0, 2, 1, 3) # b s h d

-attention =
+attention = unpad_input_only(attention, torch.squeeze(attn_mask) == 1)

 return attention.view(bs, dim)

@@ -240,8 +240,8 @@ class BertAlibiUnpadAttention(nn.Module):
 self_output = self.self(input_tensor, cu_seqlens, max_s, indices, attn_mask, bias, slopes)
 if subset_idx is not None:
 return self.output(
-
-
+index_first_axis(self_output, subset_idx),
+index_first_axis(input_tensor, subset_idx),
 )
 else:
 return self.output(self_output, input_tensor)

@@ -415,7 +415,7 @@ class FlexBertUnpadAttention(FlexBertAttentionBase):
 )
 attn = attn.view(bs, dim)
 else:
-qkv =
+qkv = pad_input(qkv, indices, cu_seqlens.shape[0] - 1, max_seqlen) # batch, max_seqlen, thd
 unpad_bs, seqlen, _ = qkv.shape

 qkv = qkv.view(unpad_bs, -1, 3, self.num_attention_heads, self.attn_head_size)

@@ -430,7 +430,7 @@ class FlexBertUnpadAttention(FlexBertAttentionBase):
 else None,
 )
 attn = attn.transpose(1, 2).view(unpad_bs, -1, dim) # b s h d
-attn =
+attn = unpad_input_only(attn, torch.squeeze(attn_mask) == 1)

 return self.out_drop(self.Wo(attn))

@@ -565,7 +565,7 @@ class FlexBertUnpadParallelAttention(FlexBertAttentionBase):
 )
 attn = attn.view(bs, dim)
 else:
-qkv =
+qkv = pad_input(qkv, indices, cu_seqlens.shape[0] - 1, max_seqlen) # batch, max_seqlen, thd
 unpad_bs, seqlen, _ = qkv.shape

 qkv = qkv.view(unpad_bs, -1, 3, self.num_attention_heads, self.attn_head_size)

@@ -580,7 +580,7 @@ class FlexBertUnpadParallelAttention(FlexBertAttentionBase):
 else None,
 )
 attn = attn.transpose(1, 2).view(unpad_bs, -1, dim) # b s h d
-attn =
+attn = unpad_input_only(attn, torch.squeeze(attn_mask) == 1)

 return self.out_drop(self.Wo(attn.view(bs, dim)))

@@ -913,7 +913,7 @@ class FlexBertUnpadRopeAttention(FlexBertAttentionBase):
 )
 attn = attn.view(bs, dim)
 else:
-qkv =
+qkv = pad_input(
 qkv, indices, cu_seqlens.shape[0] - 1, attn_mask.shape[-1]
 ) # batch, max_seqlen, thd
 unpad_bs, seqlen, *_ = qkv.shape

@@ -929,7 +929,7 @@ class FlexBertUnpadRopeAttention(FlexBertAttentionBase):
 else None,
 )
 attn = attn.transpose(1, 2).view(unpad_bs, -1, dim) # b s h d
-attn =
+attn = unpad_input_only(attn, torch.squeeze(attn_mask) == 1)

 return self.out_drop(self.Wo(attn))

@@ -1244,7 +1244,7 @@ class FlexBertUnpadRopeParallelAttention(FlexBertAttentionBase):
 )
 attn = attn.view(bs, dim)
 else:
-qkv =
+qkv = pad_input(
 qkv, indices, cu_seqlens.shape[0] - 1, attn_mask.shape[-1]
 ) # batch, max_seqlen, thd
 unpad_bs, seqlen, *_ = qkv.shape

@@ -1260,7 +1260,7 @@ class FlexBertUnpadRopeParallelAttention(FlexBertAttentionBase):
 else None,
 )
 attn = attn.transpose(1, 2).view(unpad_bs, -1, dim) # b s h d
-attn =
+attn = unpad_input_only(attn, torch.squeeze(attn_mask) == 1)

 return self.out_drop(self.Wo(attn))
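The three helpers now imported at the top of attention.py all move activations between the packed [total_tokens, dim] layout used on the unpadded path and the padded [batch, seqlen, dim] layout required by the fallback attention code. As a rough illustration only, the sketch below re-implements their behaviour in plain PyTorch; it is not the repository's bert_padding module, and the real functions may differ in signatures and edge cases.

```python
import torch

def toy_pad_input(unpadded: torch.Tensor, indices: torch.Tensor, batch: int, seqlen: int) -> torch.Tensor:
    # Scatter packed tokens [total_tokens, dim] back into a zero-padded [batch, seqlen, dim] tensor.
    out = unpadded.new_zeros(batch * seqlen, unpadded.shape[-1])
    out[indices] = unpadded
    return out.view(batch, seqlen, -1)

def toy_unpad_input_only(padded: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # Keep only the real tokens of a padded [batch, seqlen, dim] tensor -> [total_tokens, dim].
    return padded[mask]

def toy_index_first_axis(x: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    # Gather a subset of rows along the first axis (e.g. only the tokens a prediction head needs).
    return x[idx]

# Round trip: two sequences of lengths 3 and 2, padded to seqlen 4.
mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]], dtype=torch.bool)
indices = mask.flatten().nonzero().flatten()   # flat positions of the real tokens
packed = torch.randn(int(mask.sum()), 8)       # [total_tokens, dim]
padded = toy_pad_input(packed, indices, batch=2, seqlen=4)
assert torch.equal(toy_unpad_input_only(padded, mask), packed)
assert toy_index_first_axis(packed, torch.tensor([0, 3])).shape == (2, 8)
```

In the hunks above, pad_input re-inflates the packed qkv tensor before the padded attention fallback, and unpad_input_only drops the padding tokens again before the output projection.
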
layers.py
CHANGED

@@ -20,7 +20,7 @@ from typing import Optional, Union, List
 import torch
 import torch.nn as nn

-import
+from .bert_padding import unpad_input, pad_input

 from .activation import get_act_fn
 from .attention import FlexBertAttentionBase, BertAlibiUnpadAttention, get_attention_layer

@@ -155,7 +155,7 @@ class BertAlibiEncoder(nn.Module):
 # and ntokens_unpad is total number of non-padded tokens.
 # Then unpadding performs the following compression of the inputs:
 # hidden_states[ntokens,hidden] -> hidden_states[ntokens_unpad,hidden]
-hidden_states, indices, cu_seqlens, _ =
+hidden_states, indices, cu_seqlens, _ = unpad_input(hidden_states, attention_mask_bool)

 # Add alibi matrix to extended_attention_mask
 if self._current_alibi_size < seqlen:

@@ -190,7 +190,7 @@ class BertAlibiEncoder(nn.Module):
 # and ntokens_unpad is total number of non-padded tokens.
 # Then padding performs the following de-compression:
 # hidden_states[ntokens_unpad,hidden] -> hidden_states[ntokens,hidden]
-hidden_states =
+hidden_states = pad_input(hidden_states, indices, batch, seqlen)
 else:
 for i in range(len(self.layer) - 1):
 layer_module = self.layer[i]

@@ -636,7 +636,7 @@ class FlexBertUnpadEncoder(FlexBertEncoderBase):
 if indices is None and cu_seqlens is None and max_seqlen is None:
 attention_mask_bool = attention_mask.bool()
 batch, seqlen = hidden_states.shape[:2]
-hidden_states, indices, cu_seqlens, max_seqlen =
+hidden_states, indices, cu_seqlens, max_seqlen = unpad_input(
 hidden_states, attention_mask_bool
 )

@@ -649,7 +649,7 @@ class FlexBertUnpadEncoder(FlexBertEncoderBase):
 attn_mask=attention_mask,
 )

-return
+return pad_input(hidden_states, indices, batch, seqlen)
 else:
 for layer_module in self.layers:
 hidden_states = layer_module(
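layers.py applies the same module at the encoder level: unpad the hidden states once before the layer stack, run every layer on the packed tokens, then pad the final output back to [batch, seqlen, hidden]. The sketch below shows that round trip with toy stand-ins; it approximates, but is not, the actual unpad_input and pad_input from .bert_padding.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def toy_unpad_input(hidden_states: torch.Tensor, attention_mask: torch.Tensor):
    # [batch, seqlen, dim] + bool mask -> packed tokens, flat indices, cu_seqlens, max_seqlen.
    seqlens = attention_mask.sum(dim=-1, dtype=torch.int32)
    indices = attention_mask.flatten().nonzero().flatten()
    cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
    packed = hidden_states.reshape(-1, hidden_states.shape[-1])[indices]
    return packed, indices, cu_seqlens, int(seqlens.max())

def toy_pad_input(packed: torch.Tensor, indices: torch.Tensor, batch: int, seqlen: int) -> torch.Tensor:
    # Inverse of toy_unpad_input: scatter the packed tokens back into a zero-padded tensor.
    out = packed.new_zeros(batch * seqlen, packed.shape[-1])
    out[indices] = packed
    return out.view(batch, seqlen, -1)

# Encoder-style round trip: unpad once, run the layer stack on packed tokens, pad once at the end.
batch, seqlen, dim = 2, 4, 16
hidden_states = torch.randn(batch, seqlen, dim)
attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]], dtype=torch.bool)
layers = nn.ModuleList(nn.Linear(dim, dim) for _ in range(2))  # stand-ins for the transformer layers

hidden_states, indices, cu_seqlens, max_seqlen = toy_unpad_input(hidden_states, attention_mask)
for layer in layers:
    hidden_states = layer(hidden_states)       # each layer only ever sees real tokens: [total_tokens, dim]
hidden_states = toy_pad_input(hidden_states, indices, batch, seqlen)
assert hidden_states.shape == (batch, seqlen, dim)
```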