Commit c1736a8
Parent: 362ef00

feat: adapter masking wip

Signed-off-by: Meow <[email protected]>

Files changed:
- embedding.py (+21 -4)
- modeling_lora.py (+10 -2)
- modeling_xlm_roberta.py (+7 -6)
- xlm_padding.py (+9 -1)
embedding.py CHANGED

@@ -40,7 +40,7 @@ class XLMRobertaEmbeddings(nn.Module):
         if self.type_vocab_size > 0:
             self.token_type_embeddings = nn.Embedding(type_vocab_size, embed_dim, **factory_kwargs)
 
-    def forward(self, input_ids, position_ids=None, token_type_ids=None, task_type=None):
+    def forward(self, input_ids, position_ids=None, token_type_ids=None, task_type=None, adapter_mask=None):
         """
         input_ids: (batch, seqlen)
         position_ids: (batch, seqlen)
@@ -55,9 +55,25 @@ class XLMRobertaEmbeddings(nn.Module):
             emb1 = self.word_embeddings(tensor1, task_type=task_type[0])
             emb2 = self.word_embeddings(tensor2, task_type=task_type[1])
             embeddings = torch.cat((emb1, emb2), dim=0)
+
+            unique_tasks = torch.unique(adapter_mask).tolist()
+            torch_dtype = next(self.word_embeddings.parameters()).dtype
+            embeddings = torch.empty(*input_ids.shape, self.word_embeddings.embedding_dim, dtype=torch_dtype).to(input_ids.device)
+            for task in unique_tasks:
+                indices = (adapter_mask == task).nonzero(as_tuple=True)[0]
+                inp = input_ids[indices]
+                lora_kwargs = {'task_type': task} if task is not None else {}
+                emb = self.word_embeddings(inp, **lora_kwargs)
+                embeddings[indices] = emb
+
+            exit(0)
         else:
-
-
+            unique_task = torch.unique(adapter_mask)[0]
+            task1_indices = (adapter_mask == unique_task).nonzero(as_tuple=True)[0]
+            input1 = input_ids[task1_indices]
+            lora_kwargs = {'task_type': unique_task} if unique_task is not None else {}
+            embeddings = self.word_embeddings(input1, **lora_kwargs)
+
 
         if self.max_position_embeddings > 0:
             if position_ids is None:
@@ -79,7 +95,8 @@ class XLMRobertaEmbeddings(nn.Module):
                 emb2 = emb2 + token_type_embs2
                 embeddings = torch.cat((emb1, emb2), dim=0)
             else:
-
+                unique_task = torch.unique(adapter_mask)[0]
+                lora_kwargs = {'task_type': unique_task} if unique_task is not None else {}
                 token_type_embeddings = self.token_type_embeddings(token_type_ids, **lora_kwargs)
                 embeddings = embeddings + token_type_embeddings
         return embeddings
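Note: the added loop routes each row of the batch to the LoRA adapter named by its entry in adapter_mask. The following standalone sketch only illustrates that routing pattern; the per-task nn.Embedding modules and all sizes are made-up stand-ins for the LoRA-parametrized word_embeddings, not code from this repository.

import torch
import torch.nn as nn

# Stand-ins for the LoRA-parametrized word embeddings: one plain embedding per task.
vocab_size, embed_dim = 100, 8
per_task_emb = nn.ModuleDict({str(t): nn.Embedding(vocab_size, embed_dim) for t in range(2)})

input_ids = torch.randint(0, vocab_size, (4, 6))                 # (batch, seqlen)
adapter_mask = torch.tensor([0, 1, 0, 1], dtype=torch.int32)     # one adapter id per sample

embeddings = torch.empty(*input_ids.shape, embed_dim)
for task in torch.unique(adapter_mask).tolist():
    indices = (adapter_mask == task).nonzero(as_tuple=True)[0]   # rows that use this adapter
    embeddings[indices] = per_task_emb[str(task)](input_ids[indices])

print(embeddings.shape)  # torch.Size([4, 6, 8])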
modeling_lora.py CHANGED

@@ -177,7 +177,11 @@ class LoRAParametrization(nn.Module):
         )
 
         def new_forward(self, input, task_type, residual=False):
-
+            if isinstance(task_type, str):
+                task_idx = adaptation_map[task_type] if task_type else None
+            else:
+                task_idx = task_type
+
             if task_idx is not None:
                 weights = self.parametrizations.weight[0].lora_forward(self.weight, current_task=task_idx)
             else:
@@ -205,7 +209,11 @@ class LoRAParametrization(nn.Module):
         )
 
         def new_forward(self, input, task_type):
-
+            if isinstance(task_type, str):
+                task_idx = adaptation_map[task_type] if task_type else None
+            else:
+                task_idx = task_type
+
             if task_idx is not None:
                 weights = self.parametrizations.weight[0].lora_forward(self.weight, current_task=task_idx)
             else:
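Note: both new_forward variants now accept either a task name or an already-resolved integer index, so the same code path works whether the caller passes a string task_type or the per-sample ids coming from an adapter mask. A minimal sketch of that resolution; the adaptation_map contents below are illustrative, the real mapping is not shown in this commit.

# Illustrative adaptation_map; the real mapping comes from the model configuration.
adaptation_map = {'retrieval.query': 0, 'retrieval.passage': 1, 'classification': 2}

def resolve_task_idx(task_type):
    # Mirrors the added branch: accept a task name or an already-resolved index.
    if isinstance(task_type, str):
        return adaptation_map[task_type] if task_type else None
    return task_type  # int index from an adapter mask, or None

print(resolve_task_idx('classification'))  # 2
print(resolve_task_idx(1))                 # 1
print(resolve_task_idx(None))              # None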
modeling_xlm_roberta.py CHANGED

@@ -204,7 +204,7 @@ class XLMRobertaEncoder(nn.Module):
     def gradient_checkpointing(self, value):
         self._grad_checkpointing = value
 
-    def forward(self, hidden_states, key_padding_mask=None, subset_mask=None, task_type=None):
+    def forward(self, hidden_states, key_padding_mask=None, subset_mask=None, task_type=None, adapter_mask=None):
         """If subset_mask is not None, we only want output for the subset of the sequence.
         This means that we only compute the last layer output for these tokens.
         subset_mask: (batch, seqlen), dtype=torch.bool
@@ -230,10 +230,10 @@ class XLMRobertaEncoder(nn.Module):
             hidden_states = hidden_states[subset_mask]
         else:
             batch, seqlen = hidden_states.shape[:2]
-            hidden_states, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(
-                hidden_states, key_padding_mask
+            hidden_states, indices, cu_seqlens, max_seqlen_in_batch, cu_adapter_mask = unpad_input(
+                hidden_states, key_padding_mask, adapter_mask
             )
-            mixer_kwargs = {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen_in_batch, "task_type": task_type}
+            mixer_kwargs = {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen_in_batch, "task_type": task_type, "cu_adapter_mask": cu_adapter_mask}
             if subset_mask is None:
                 for layer in self.layers:
                     if self._grad_checkpointing:
@@ -649,6 +649,7 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
         masked_tokens_mask: (batch, seqlen), dtype=torch.bool
         """
         task_type = kwargs.pop('task_type', None)
+        adapter_mask = kwargs.pop('adapter_mask', None)
         if kwargs:
             for key, value in kwargs.items():
                 if value is not None:
@@ -662,7 +663,7 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
         )
 
         hidden_states = self.embeddings(
-            input_ids, position_ids=position_ids, token_type_ids=token_type_ids, task_type=task_type
+            input_ids, position_ids=position_ids, token_type_ids=token_type_ids, task_type=task_type, adapter_mask=adapter_mask
         )
         # TD [2022-12:18]: Don't need to force residual in fp32
         # BERT puts embedding LayerNorm before embedding dropout.
@@ -686,7 +687,7 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
         subset_mask = None
 
         sequence_output = self.encoder(
-            hidden_states, key_padding_mask=attention_mask, subset_mask=subset_mask, task_type=task_type
+            hidden_states, key_padding_mask=attention_mask, subset_mask=subset_mask, task_type=task_type, adapter_mask=adapter_mask
        )
 
         if masked_tokens_mask is None:
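Note: at the model level this change is plumbing only — XLMRobertaModel.forward now pops an optional adapter_mask from kwargs and threads it through the embeddings and the encoder. A hypothetical call sketch with one integer adapter id per sample; the input values, the adaptation_map contents, and the commented model-loading call are assumptions for illustration, not part of this commit.

import torch

# Assumed inputs: already-tokenized ids and a padding mask (values are made up).
input_ids = torch.tensor([[0, 5, 6, 2], [0, 7, 2, 1]])
attention_mask = torch.tensor([[1, 1, 1, 1], [1, 1, 1, 0]])

# One adapter id per sample; here derived from task names via an assumed map.
adaptation_map = {'retrieval.query': 0, 'retrieval.passage': 1}
tasks = ['retrieval.query', 'retrieval.passage']
adapter_mask = torch.tensor([adaptation_map[t] for t in tasks], dtype=torch.int32)

# model = AutoModel.from_pretrained('<repo>', trust_remote_code=True)   # assumed setup
# output = model(input_ids, attention_mask=attention_mask, adapter_mask=adapter_mask)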
xlm_padding.py CHANGED

@@ -98,7 +98,7 @@ class IndexFirstAxisResidual(torch.autograd.Function):
 index_first_axis_residual = IndexFirstAxisResidual.apply
 
 
-def unpad_input(hidden_states, attention_mask):
+def unpad_input(hidden_states, attention_mask, adapter_mask):
     """
     Arguments:
         hidden_states: (batch, seqlen, ...)
@@ -113,6 +113,13 @@ def unpad_input(hidden_states, attention_mask):
     indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
     max_seqlen_in_batch = seqlens_in_batch.max().item()
     cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
+
+    cu_adapter_mask = torch.empty(cu_seqlens[-1], dtype=torch.int32)
+    for i in range(len(adapter_mask)):
+        start_idx = cu_seqlens[i]
+        end_idx = cu_seqlens[i + 1]
+        cu_adapter_mask[start_idx:end_idx] = adapter_mask[i]
+
     # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
     # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
     # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
@@ -123,6 +130,7 @@ def unpad_input(hidden_states, attention_mask):
         indices,
         cu_seqlens,
         max_seqlen_in_batch,
+        cu_adapter_mask,
     )
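Note: after unpad_input flattens the batch into one packed token stream, the per-sequence adapter ids have to be repeated per token, which is what the new cu_adapter_mask loop does using the cumulative sequence lengths. A standalone sketch of that expansion with small made-up numbers:

import torch
import torch.nn.functional as F

seqlens_in_batch = torch.tensor([3, 2, 4], dtype=torch.int32)   # valid tokens per sequence
adapter_mask = torch.tensor([0, 1, 0], dtype=torch.int32)       # adapter id per sequence

cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
# cu_seqlens == tensor([0, 3, 5, 9], dtype=torch.int32)

cu_adapter_mask = torch.empty(int(cu_seqlens[-1]), dtype=torch.int32)
for i in range(len(adapter_mask)):
    cu_adapter_mask[cu_seqlens[i]:cu_seqlens[i + 1]] = adapter_mask[i]

print(cu_adapter_mask.tolist())  # [0, 0, 0, 1, 1, 0, 0, 0, 0]

# Loop-free alternative (not what the commit does):
# cu_adapter_mask = torch.repeat_interleave(adapter_mask, seqlens_in_batch.long())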