NohTow committed
Commit 170de2e · 1 Parent(s): a243956

New version

Files changed (3)
  1. loss.py +0 -30
  2. model.py +0 -1684
  3. modeling_flexbert.py +14 -14
loss.py DELETED
@@ -1,30 +0,0 @@
1
- # Copyright 2024 **AUTHORS_TODO**
2
- # License: Apache-2.0
3
-
4
- import inspect
5
- import torch.nn as nn
6
- from .configuration_bert import FlexBertConfig
7
-
8
- try:
9
- from flash_attn.losses.cross_entropy import CrossEntropyLoss
10
- except ImportError:
11
- CrossEntropyLoss = None
12
-
13
- LOSS2CLS = {
14
- "cross_entropy": nn.CrossEntropyLoss,
15
- "binary_cross_entropy": nn.BCEWithLogitsLoss,
16
- "mean_squared_error": nn.MSELoss,
17
- }
18
-
19
- if CrossEntropyLoss is not None:
20
- LOSS2CLS["fa_cross_entropy"] = CrossEntropyLoss
21
-
22
-
23
- def get_loss_fn(config: FlexBertConfig) -> nn.Module:
24
- try:
25
- loss_class = LOSS2CLS[config.loss_function]
26
- signature = inspect.signature(loss_class)
27
- loss_kwargs = {k: v for k, v in config.loss_kwargs.items() if k in signature.parameters}
28
- return loss_class(**loss_kwargs)
29
- except KeyError:
30
- raise ValueError(f"Invalid loss function type: {config.loss_function}, must be one of {LOSS2CLS.keys()}.")
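For reference, a minimal sketch of how the deleted `get_loss_fn` resolves a loss module from the config. The `loss_function` and `loss_kwargs` fields mirror the code above; the stand-in config class and its values are illustrative assumptions, not part of the repository.

```python
# Illustrative sketch of the deleted get_loss_fn logic using plain torch.nn losses.
import inspect
import torch.nn as nn

LOSS2CLS = {
    "cross_entropy": nn.CrossEntropyLoss,
    "binary_cross_entropy": nn.BCEWithLogitsLoss,
    "mean_squared_error": nn.MSELoss,
}

class DummyConfig:  # hypothetical stand-in for FlexBertConfig's loss fields
    loss_function = "cross_entropy"
    loss_kwargs = {"ignore_index": -100, "unsupported_kwarg": 1}

config = DummyConfig()
loss_class = LOSS2CLS[config.loss_function]
# Only kwargs accepted by the loss constructor are forwarded; the rest are dropped.
signature = inspect.signature(loss_class)
loss_kwargs = {k: v for k, v in config.loss_kwargs.items() if k in signature.parameters}
loss_fn = loss_class(**loss_kwargs)
print(loss_fn)  # CrossEntropyLoss()
```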
 
model.py DELETED
@@ -1,1684 +0,0 @@
1
- # Copyright 2024 **AUTHORS_TODO**
2
- # License: Apache-2.0
3
-
4
- # RMSNorm Implementation: Copyright Meta (from their Llama RMSNorm implementation)
5
- # License: LLAMA 2 COMMUNITY LICENSE AGREEMENT
6
-
7
- # Copyright 2022 Jonas Geiping
8
- # License: MIT
9
-
10
- # Copyright 2022 MosaicML Examples authors
11
- # SPDX-License-Identifier: Apache-2.0
12
-
13
- # Copyright 2023 MosaicML Examples authors
14
- # SPDX-License-Identifier: Apache-2.0
15
-
16
- # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
17
- # Copyright (c) 2018-2021, NVIDIA CORPORATION. All rights reserved.
18
- # Copyright (c) 2023, Tri Dao.
19
-
20
- """Implements Mosaic BERT, with an eye towards the Hugging Face API.
21
-
22
- Mosaic BERT improves performance over Hugging Face BERT through the following:
23
-
24
- 1. ALiBi. This architectural change removes positional embeddings and instead encodes positional
25
- information through attention biases based on query-key position distance. It improves the effectiveness
26
- of training with shorter sequence lengths by enabling extrapolation to longer sequences.
27
-
28
- 2. Gated Linear Units (GLU). This architectural change replaces the FFN component of the BERT layer
29
- to improve overall expressiveness, providing better convergence properties.
30
-
31
- 3. Flash Attention. MosaicBERT's self-attention layer uses Flash Attention, which dramatically
33
- improves the speed of self-attention. Our implementation relies on a bleeding-edge version that
33
- supports attention biases, which allows us to use Flash Attention with ALiBi.
34
-
35
- 4. Unpadding. Padding is often used to simplify batching across sequences of different lengths. Standard BERT
36
- implementations waste computation on padded tokens. MosaicBERT internally unpads to reduce unnecessary computation
37
- and improve speed. It does this without changing how the user interfaces with the model, thereby
38
- preserving the simple API of standard implementations.
39
-
40
-
41
- Currently, MosaicBERT is available for masked language modeling :class:`BertForMaskedLM` and sequence
42
- classification :class:`BertForSequenceClassification`. We aim to expand this catalogue in future releases.
43
-
44
- See :file:`./mosaic_bert.py` for utilities to simplify working with MosaicBERT in Composer, and for example usage
45
- of the core Mosaic BERT classes.
46
- """
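As a point of reference for item 1 above, here is a minimal, self-contained sketch of the ALiBi idea: each attention head adds a linear bias proportional to query-key distance instead of using positional embeddings. This is an illustration of the concept only, not this repository's implementation; the head count, sequence length, and power-of-two slope formula are assumptions taken from the ALiBi paper.

```python
# Minimal ALiBi sketch (conceptual illustration; not this repo's implementation).
import torch

def alibi_bias(num_heads: int, seq_len: int) -> torch.Tensor:
    # Geometric slopes from the ALiBi paper, assuming num_heads is a power of two.
    slopes = torch.tensor([2.0 ** (-8.0 * (h + 1) / num_heads) for h in range(num_heads)])
    positions = torch.arange(seq_len)
    distance = (positions[None, :] - positions[:, None]).abs()   # (seq_len, seq_len)
    return -slopes[:, None, None] * distance[None, :, :]         # (num_heads, seq_len, seq_len)

bias = alibi_bias(num_heads=12, seq_len=128)
# The bias is added to the raw attention scores before the softmax, e.g.:
# scores = q @ k.transpose(-2, -1) / head_dim**0.5 + bias
```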
47
-
48
- import logging
49
- import os
50
- import sys
51
- import warnings
52
- from dataclasses import dataclass
53
- from typing import List, Optional, Tuple, Union
54
-
55
- # Add folder root to path to allow us to use relative imports regardless of what directory the script is run from
56
- sys.path.append(os.path.dirname(os.path.realpath(__file__)))
57
-
58
- import torch
59
- import torch.nn as nn
60
- from einops import rearrange
61
- from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present
62
- from transformers.modeling_outputs import (
63
- MaskedLMOutput,
64
- ModelOutput,
65
- MultipleChoiceModelOutput,
66
- SequenceClassifierOutput,
67
- )
68
- from transformers.models.bert.modeling_bert import BertPreTrainedModel
69
-
70
- from bert_padding import index_put_first_axis
71
-
72
- from src.bert_layers.activation import get_act_fn
73
- from src.bert_layers.attention import (
74
- FlexBertPaddedAttention,
75
- FlexBertPaddedParallelAttention,
76
- FlexBertPaddedRopeAttention,
77
- FlexBertPaddedRopeParallelAttention,
78
- FlexBertUnpadAttention,
79
- FlexBertUnpadParallelAttention,
80
- FlexBertUnpadRopeAttention,
81
- FlexBertUnpadRopeParallelAttention,
82
- )
83
- from src.bert_layers.configuration_bert import FlexBertConfig
84
- from src.bert_layers.embeddings import (
85
- BertAlibiEmbeddings,
86
- FlexBertAbsoluteEmbeddings,
87
- FlexBertCompiledSansPositionEmbeddings,
88
- FlexBertSansPositionEmbeddings,
89
- get_embedding_layer,
90
- )
91
- from src.bert_layers.initialization import (
92
- ModuleType,
93
- TileLinear,
94
- TileMode,
95
- init_weights,
96
- tile_embedding,
97
- tile_linear,
98
- tile_norm,
99
- )
100
- from src.bert_layers.layers import (
101
- BertAlibiEncoder,
102
- BertPooler,
103
- BertPredictionHeadTransform,
104
- FlexBertCompileUnpadPreNormLayer,
105
- FlexBertPaddedEncoder,
106
- FlexBertPaddedParallelPreNormLayer,
107
- FlexBertPaddedPostNormLayer,
108
- FlexBertPaddedPreNormLayer,
109
- FlexBertUnpadEncoder,
110
- FlexBertUnpadParallelPreNormLayer,
111
- FlexBertUnpadPostNormLayer,
112
- FlexBertUnpadPreNormLayer,
113
- get_encoder_layer,
114
- )
115
- from src.bert_layers.loss import get_loss_fn
116
- from src.bert_layers.mlp import FlexBertGLU, FlexBertMLP, FlexBertParallelGLU
117
- from src.bert_layers.normalization import get_norm_layer
118
- from src.bert_layers.padding import pad_input, unpad_input
119
-
120
- logger = logging.getLogger(__name__)
121
-
122
-
123
- def _count_parameters(model: nn.Module, trainable: bool = True) -> int:
124
- if trainable:
125
- return sum(p.numel() for p in model.parameters() if p.requires_grad)
126
- else:
127
- return sum(p.numel() for p in model.parameters())
128
-
129
-
130
- class BertModel(BertPreTrainedModel):
131
- """Overall BERT model.
132
-
133
- Args:
134
- config: a BertConfig class instance with the configuration to build a new model
135
-
136
- Inputs:
137
- `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
138
- with the word token indices in the vocabulary (see the token preprocessing logic in the scripts
139
- `extract_features.py`, `run_classifier.py` and `run_squad.py`)
140
- `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
141
- types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
142
- a `sentence B` token (see BERT paper for more details).
143
- `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
144
- selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
145
- input sequence length in the current batch. It's the mask that we typically use for attention when
146
- a batch has varying length sentences.
147
- `output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output as described below. Default: `False`.
148
-
149
- Outputs: Tuple of (encoded_layers, pooled_output)
150
- `encoded_layers`: controlled by `output_all_encoded_layers` argument:
151
- - `output_all_encoded_layers=True`: outputs a list of the full sequences of encoded-hidden-states at the end
152
- of each attention block (i.e. 12 full sequences for BERT-base, 24 for BERT-large), each
153
- encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size],
154
- - `output_all_encoded_layers=False`: outputs only the full sequence of hidden-states corresponding
155
- to the last attention block of shape [batch_size, sequence_length, hidden_size],
156
- `pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a
157
- classifier pretrained on top of the hidden state associated with the first token of the
158
- input (`CLS`) to train on the Next-Sentence task (see BERT's paper).
159
-
160
- Example usage:
161
- ```python
162
- # Already been converted into WordPiece token ids
163
- input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
164
- input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
165
- token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
166
- config = modeling.BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
167
- num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
168
- model = BertModel(config=config)
169
- all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
170
- ```
171
- """
172
-
173
- def __init__(
174
- self,
175
- config,
176
- add_pooling_layer: bool = True,
177
- ):
178
- super(BertModel, self).__init__(config)
179
- self.embeddings = BertAlibiEmbeddings(config)
180
- self.encoder = BertAlibiEncoder(config)
181
- self.pooler = BertPooler(config) if add_pooling_layer else None
182
- self.post_init()
183
-
184
- def get_input_embeddings(self):
185
- return self.embeddings.word_embeddings
186
-
187
- def set_input_embeddings(self, value):
188
- self.embeddings.word_embeddings = value
189
-
190
- def forward(
191
- self,
192
- input_ids: torch.Tensor,
193
- token_type_ids: Optional[torch.Tensor] = None,
194
- attention_mask: Optional[torch.Tensor] = None,
195
- position_ids: Optional[torch.Tensor] = None,
196
- output_all_encoded_layers: Optional[bool] = False,
197
- masked_tokens_mask: Optional[torch.Tensor] = None,
198
- **kwargs,
199
- ) -> Tuple[Union[List[torch.Tensor], torch.Tensor], Optional[torch.Tensor]]:
200
- if attention_mask is None:
201
- attention_mask = torch.ones_like(input_ids)
202
- if token_type_ids is None:
203
- token_type_ids = torch.zeros_like(input_ids)
204
-
205
- embedding_output = self.embeddings(input_ids, token_type_ids, position_ids)
206
-
207
- subset_mask = []
208
- first_col_mask = []
209
-
210
- if masked_tokens_mask is None:
211
- subset_mask = None
212
- else:
213
- first_col_mask = torch.zeros_like(masked_tokens_mask)
214
- first_col_mask[:, 0] = True
215
- subset_mask = masked_tokens_mask | first_col_mask
216
-
217
- encoder_outputs = self.encoder(
218
- embedding_output,
219
- attention_mask,
220
- output_all_encoded_layers=output_all_encoded_layers,
221
- subset_mask=subset_mask,
222
- )
223
-
224
- if masked_tokens_mask is None:
225
- sequence_output = encoder_outputs[-1]
226
- pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
227
- else:
228
- # TD [2022-03-01]: the indexing here is very tricky.
229
- attention_mask_bool = attention_mask.bool()
230
- subset_idx = subset_mask[attention_mask_bool] # type: ignore
231
- sequence_output = encoder_outputs[-1][masked_tokens_mask[attention_mask_bool][subset_idx]]
232
- if self.pooler is not None:
233
- pool_input = encoder_outputs[-1][first_col_mask[attention_mask_bool][subset_idx]]
234
- pooled_output = self.pooler(pool_input, pool=False)
235
- else:
236
- pooled_output = None
237
-
238
- if not output_all_encoded_layers:
239
- encoder_outputs = sequence_output
240
-
241
- if self.pooler is not None:
242
- return encoder_outputs, pooled_output
243
-
244
- return encoder_outputs, None
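To make the subset-mask indexing in `forward` above easier to follow, here is a toy, self-contained sketch of how the masked-token rows and CLS rows are recovered from the encoder's subset output. The tensor values are made up for illustration; only the indexing pattern mirrors the code above.

```python
# Toy illustration of the subset_mask indexing used in BertModel.forward above.
import torch

attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]]).bool()       # real (non-pad) tokens
masked_tokens_mask = torch.tensor([[0, 1, 0, 0],
                                   [0, 1, 0, 0]]).bool()   # positions with MLM labels

first_col_mask = torch.zeros_like(masked_tokens_mask)
first_col_mask[:, 0] = True
subset_mask = masked_tokens_mask | first_col_mask

# The encoder only returns the unpadded rows where subset_mask is True.
subset_idx = subset_mask[attention_mask]                    # bool over unpadded tokens
subset_output = torch.arange(int(subset_idx.sum()))         # stand-in for encoder_outputs[-1]

# Pick the MLM rows and the CLS rows back out of that subset.
mlm_rows = subset_output[masked_tokens_mask[attention_mask][subset_idx]]
cls_rows = subset_output[first_col_mask[attention_mask][subset_idx]]
print(mlm_rows)  # tensor([1, 3]) -> rows fed to the MLM head
print(cls_rows)  # tensor([0, 2]) -> rows fed to the pooler
```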
245
-
246
-
247
- ###################
248
- # Bert Heads
249
- ###################
250
- class BertLMPredictionHead(nn.Module):
251
- def __init__(self, config, bert_model_embedding_weights):
252
- super().__init__()
253
- self.transform = BertPredictionHeadTransform(config)
254
- # The output weights are the same as the input embeddings, but there is
255
- # an output-only bias for each token.
256
- self.decoder = nn.Linear(bert_model_embedding_weights.size(1), bert_model_embedding_weights.size(0))
257
- self.decoder.weight = bert_model_embedding_weights
258
-
259
- def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
260
- hidden_states = self.transform(hidden_states)
261
- hidden_states = self.decoder(hidden_states)
262
- return hidden_states
263
-
264
-
265
- class BertOnlyMLMHead(nn.Module):
266
- def __init__(self, config, bert_model_embedding_weights):
267
- super().__init__()
268
- self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights)
269
-
270
- def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
271
- prediction_scores = self.predictions(sequence_output)
272
- return prediction_scores
273
-
274
-
275
- class BertOnlyNSPHead(nn.Module):
276
- def __init__(self, config):
277
- super().__init__()
278
- self.seq_relationship = nn.Linear(config.hidden_size, 2)
279
-
280
- def forward(self, pooled_output: torch.Tensor) -> torch.Tensor:
281
- seq_relationship_score = self.seq_relationship(pooled_output)
282
- return seq_relationship_score
283
-
284
-
285
- #####################
286
- # Various Bert models
287
- #####################
288
-
289
-
290
- class BertForPreTraining(BertPreTrainedModel):
291
- # TBD: Coming in Future Commit
292
- pass
293
-
294
-
295
- class BertLMHeadModel(BertPreTrainedModel):
296
- # TBD: Coming in Future Commit
297
- pass
298
-
299
-
300
- class BertForMaskedLM(BertPreTrainedModel):
301
- def __init__(self, config):
302
- super().__init__(config)
303
-
304
- if config.is_decoder:
305
- warnings.warn(
306
- "If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for "
307
- "bi-directional self-attention."
308
- )
309
-
310
- self.bert = BertModel(config, add_pooling_layer=False)
311
- self.cls = BertOnlyMLMHead(config, self.bert.embeddings.word_embeddings.weight)
312
-
313
- # Initialize weights and apply final processing
314
- self.post_init()
315
-
316
- @classmethod
317
- def from_composer(
318
- cls,
319
- pretrained_checkpoint,
320
- state_dict=None,
321
- cache_dir=None,
322
- from_tf=False,
323
- config=None,
324
- *inputs,
325
- **kwargs,
326
- ):
327
- """Load from pre-trained."""
328
- model = cls(config, *inputs, **kwargs)
329
- if from_tf:
330
- raise ValueError("Mosaic BERT does not support loading TensorFlow weights.")
331
-
332
- state_dict = torch.load(pretrained_checkpoint)
333
- # If the state_dict was saved after wrapping with `composer.HuggingFaceModel`, it takes on the `model` prefix
334
- consume_prefix_in_state_dict_if_present(state_dict, prefix="model.")
335
- missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
336
-
337
- if len(missing_keys) > 0:
338
- logger.warning(f"Found these missing keys in the checkpoint: {', '.join(missing_keys)}")
339
- if len(unexpected_keys) > 0:
340
- logger.warning(f"Found these unexpected keys in the checkpoint: {', '.join(unexpected_keys)}")
341
-
342
- return model
343
-
344
- def get_output_embeddings(self):
345
- return self.cls.predictions.decoder
346
-
347
- def set_output_embeddings(self, new_embeddings):
348
- self.cls.predictions.decoder = new_embeddings
349
-
350
- def forward(
351
- self,
352
- input_ids: Optional[torch.Tensor] = None,
353
- attention_mask: Optional[torch.Tensor] = None,
354
- token_type_ids: Optional[torch.Tensor] = None,
355
- position_ids: Optional[torch.Tensor] = None,
356
- head_mask: Optional[torch.Tensor] = None,
357
- inputs_embeds: Optional[torch.Tensor] = None,
358
- encoder_hidden_states: Optional[torch.Tensor] = None,
359
- encoder_attention_mask: Optional[torch.Tensor] = None,
360
- labels: Optional[torch.Tensor] = None,
361
- output_attentions: Optional[bool] = None,
362
- output_hidden_states: Optional[bool] = None,
363
- return_dict: Optional[bool] = None,
364
- ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
365
- # labels should be a `torch.LongTensor` of shape
366
- # `(batch_size, sequence_length)`. These are used for computing the
367
- # masked language modeling loss.
368
- #
369
- # Indices should be in `[-100, 0, ..., config.vocab_size]` (see
370
- # `input_ids` docstring) Tokens with indices set to `-100` are ignored
371
- # (masked), the loss is only computed for the tokens with labels in `[0,
372
- # ..., config.vocab_size]`
373
- #
374
- # Prediction scores are only computed for masked tokens and the (bs,
375
- # seqlen) dimensions are flattened
376
- if (input_ids is not None) == (inputs_embeds is not None):
377
- raise ValueError("Must specify either input_ids or inputs_embeds!")
378
-
379
- if labels is None:
380
- masked_tokens_mask = None
381
- else:
382
- masked_tokens_mask = labels > 0
383
-
384
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
385
-
386
- outputs = self.bert(
387
- input_ids,
388
- attention_mask=attention_mask,
389
- token_type_ids=token_type_ids,
390
- position_ids=position_ids,
391
- head_mask=head_mask,
392
- inputs_embeds=inputs_embeds,
393
- encoder_hidden_states=encoder_hidden_states,
394
- encoder_attention_mask=encoder_attention_mask,
395
- output_attentions=output_attentions,
396
- output_hidden_states=output_hidden_states,
397
- return_dict=return_dict,
398
- masked_tokens_mask=masked_tokens_mask,
399
- )
400
-
401
- sequence_output = outputs[0]
402
- prediction_scores = self.cls(sequence_output)
403
-
404
- loss = None
405
- if labels is not None:
406
- # Compute loss
407
- loss_fct = nn.CrossEntropyLoss()
408
- masked_token_idx = torch.nonzero(labels.flatten() > 0, as_tuple=False).flatten()
409
- loss = loss_fct(prediction_scores, labels.flatten()[masked_token_idx])
410
-
411
- assert input_ids is not None, "Coding error; please open an issue"
412
- batch, seqlen = input_ids.shape[:2]
413
- prediction_scores = rearrange(
414
- index_put_first_axis(prediction_scores, masked_token_idx, batch * seqlen),
415
- "(b s) d -> b s d",
416
- b=batch,
417
- )
418
-
419
- if not return_dict:
420
- output = (prediction_scores,) + outputs[2:]
421
- return ((loss,) + output) if loss is not None else output
422
-
423
- return MaskedLMOutput(
424
- loss=loss,
425
- logits=prediction_scores,
426
- hidden_states=None,
427
- attentions=None,
428
- )
429
-
430
- def prepare_inputs_for_generation(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **model_kwargs):
431
- input_shape = input_ids.shape
432
- effective_batch_size = input_shape[0]
433
-
434
- # add a dummy token
435
- if self.config.pad_token_id is None:
436
- raise ValueError("The PAD token should be defined for generation")
437
-
438
- attention_mask = torch.cat(
439
- [attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))],
440
- dim=-1,
441
- )
442
- dummy_token = torch.full(
443
- (effective_batch_size, 1),
444
- self.config.pad_token_id,
445
- dtype=torch.long,
446
- device=input_ids.device,
447
- )
448
- input_ids = torch.cat([input_ids, dummy_token], dim=1)
449
-
450
- return {"input_ids": input_ids, "attention_mask": attention_mask}
451
-
452
-
453
- class BertForNextSentencePrediction(BertPreTrainedModel):
454
- # TBD: Push in future commit
455
- pass
456
-
457
-
458
- class BertForSequenceClassification(BertPreTrainedModel):
459
- """Bert Model transformer with a sequence classification/regression head.
460
-
461
- This head is just a linear layer on top of the pooled output. Used for,
462
- e.g., GLUE tasks.
463
- """
464
-
465
- def __init__(self, config):
466
- super().__init__(config)
467
- self.num_labels = config.num_labels
468
- self.config = config
469
-
470
- self.bert = BertModel(config)
471
- classifier_dropout = (
472
- config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
473
- )
474
- self.dropout = nn.Dropout(classifier_dropout)
475
- self.classifier = nn.Linear(config.hidden_size, config.num_labels)
476
-
477
- # Initialize weights and apply final processing
478
- self.post_init()
479
-
480
- @classmethod
481
- def from_composer(
482
- cls,
483
- pretrained_checkpoint,
484
- state_dict=None,
485
- cache_dir=None,
486
- from_tf=False,
487
- config=None,
488
- *inputs,
489
- **kwargs,
490
- ):
491
- """Load from pre-trained."""
492
- model = cls(config, *inputs, **kwargs)
493
- if from_tf:
494
- raise ValueError("Mosaic BERT does not support loading TensorFlow weights.")
495
-
496
- state_dict = torch.load(pretrained_checkpoint)
497
- # If the state_dict was saved after wrapping with `composer.HuggingFaceModel`, it takes on the `model` prefix
498
- consume_prefix_in_state_dict_if_present(state_dict, prefix="model.")
499
- missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
500
-
501
- if len(missing_keys) > 0:
502
- logger.warning(f"Found these missing keys in the checkpoint: {', '.join(missing_keys)}")
503
- if len(unexpected_keys) > 0:
504
- logger.warning(f"Found these unexpected keys in the checkpoint: {', '.join(unexpected_keys)}")
505
-
506
- return model
507
-
508
- def forward(
509
- self,
510
- input_ids: Optional[torch.Tensor] = None,
511
- attention_mask: Optional[torch.Tensor] = None,
512
- token_type_ids: Optional[torch.Tensor] = None,
513
- position_ids: Optional[torch.Tensor] = None,
514
- head_mask: Optional[torch.Tensor] = None,
515
- inputs_embeds: Optional[torch.Tensor] = None,
516
- labels: Optional[torch.Tensor] = None,
517
- output_attentions: Optional[bool] = None,
518
- output_hidden_states: Optional[bool] = None,
519
- return_dict: Optional[bool] = None,
520
- ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
521
- # labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
522
- # Labels for computing the sequence classification/regression loss.
523
- # Indices should be in `[0, ..., config.num_labels - 1]`.
524
- # If `config.num_labels == 1` a regression loss is computed
525
- # (mean-square loss). If `config.num_labels > 1` a classification loss
526
- # is computed (cross-entropy).
527
-
528
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
529
-
530
- outputs = self.bert(
531
- input_ids,
532
- attention_mask=attention_mask,
533
- token_type_ids=token_type_ids,
534
- position_ids=position_ids,
535
- head_mask=head_mask,
536
- inputs_embeds=inputs_embeds,
537
- output_attentions=output_attentions,
538
- output_hidden_states=output_hidden_states,
539
- return_dict=return_dict,
540
- )
541
-
542
- pooled_output = outputs[1]
543
-
544
- pooled_output = self.dropout(pooled_output)
545
- logits = self.classifier(pooled_output)
546
-
547
- loss = None
548
- if labels is not None:
549
- # Compute loss
550
- if self.config.problem_type is None:
551
- if self.num_labels == 1:
552
- self.config.problem_type = "regression"
553
- elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
554
- self.config.problem_type = "single_label_classification"
555
- else:
556
- self.config.problem_type = "multi_label_classification"
557
-
558
- if self.config.problem_type == "regression":
559
- loss_fct = nn.MSELoss()
560
- if self.num_labels == 1:
561
- loss = loss_fct(logits.squeeze(), labels.squeeze())
562
- else:
563
- loss = loss_fct(logits, labels)
564
- elif self.config.problem_type == "single_label_classification":
565
- loss_fct = nn.CrossEntropyLoss()
566
- loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
567
- elif self.config.problem_type == "multi_label_classification":
568
- loss_fct = nn.BCEWithLogitsLoss()
569
- loss = loss_fct(logits, labels)
570
-
571
- if not return_dict:
572
- output = (logits,) + outputs[2:]
573
- return ((loss,) + output) if loss is not None else output
574
-
575
- return SequenceClassifierOutput(
576
- loss=loss,
577
- logits=logits,
578
- hidden_states=None,
579
- attentions=None,
580
- )
581
-
582
-
583
- class BertForMultipleChoice(BertPreTrainedModel):
584
- """
585
- Bert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
586
- softmax) e.g. for RocStories/SWAG tasks.
587
- """
588
-
589
- def __init__(self, config):
590
- super().__init__(config)
591
- self.num_labels = config.num_labels
592
- self.config = config
593
-
594
- self.bert = BertModel(config)
595
- classifier_dropout = (
596
- config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
597
- )
598
- self.dropout = nn.Dropout(classifier_dropout)
599
-
600
- # In multiple choice tasks, all choices are submitted in a batch, and
601
- # we compute a logit for each option independently. The logits are then
602
- # normalized in the forward pass to get a probability distribution over
603
- # the choices.
604
- self.classifier = nn.Linear(config.hidden_size, 1)
605
-
606
- # Initialize weights and apply final processing
607
- self.post_init()
608
-
609
- @classmethod
610
- def from_composer(
611
- cls,
612
- pretrained_checkpoint,
613
- state_dict=None,
614
- cache_dir=None,
615
- from_tf=False,
616
- config=None,
617
- *inputs,
618
- **kwargs,
619
- ):
620
- """Load from pre-trained."""
621
- model = cls(config, *inputs, **kwargs)
622
- if from_tf:
623
- raise ValueError("Mosaic BERT does not support loading TensorFlow weights.")
624
-
625
- state_dict = torch.load(pretrained_checkpoint)
626
- # If the state_dict was saved after wrapping with `composer.HuggingFaceModel`, it takes on the `model` prefix
627
- consume_prefix_in_state_dict_if_present(state_dict, prefix="model.")
628
- missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
629
-
630
- if len(missing_keys) > 0:
631
- logger.warning(f"Found these missing keys in the checkpoint: {', '.join(missing_keys)}")
632
- if len(unexpected_keys) > 0:
633
- logger.warning(f"Found these unexpected keys in the checkpoint: {', '.join(unexpected_keys)}")
634
-
635
- return model
636
-
637
- def forward(
638
- self,
639
- input_ids: Optional[torch.Tensor] = None,
640
- attention_mask: Optional[torch.Tensor] = None,
641
- token_type_ids: Optional[torch.Tensor] = None,
642
- position_ids: Optional[torch.Tensor] = None,
643
- head_mask: Optional[torch.Tensor] = None,
644
- inputs_embeds: Optional[torch.Tensor] = None,
645
- labels: Optional[torch.Tensor] = None,
646
- output_attentions: Optional[bool] = None,
647
- output_hidden_states: Optional[bool] = None,
648
- return_dict: Optional[bool] = None,
649
- ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
650
- r"""
651
- labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
652
- Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
653
- num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
654
- `input_ids` above)
655
- """
656
-
657
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
658
- num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
659
-
660
- input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
661
- attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
662
- token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
663
- position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
664
- inputs_embeds = (
665
- inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
666
- if inputs_embeds is not None
667
- else None
668
- )
669
-
670
- outputs = self.bert(
671
- input_ids,
672
- attention_mask=attention_mask,
673
- token_type_ids=token_type_ids,
674
- position_ids=position_ids,
675
- head_mask=head_mask,
676
- inputs_embeds=inputs_embeds,
677
- output_attentions=output_attentions,
678
- output_hidden_states=output_hidden_states,
679
- return_dict=return_dict,
680
- )
681
-
682
- pooled_output = outputs[1]
683
-
684
- pooled_output = self.dropout(pooled_output)
685
- logits = self.classifier(pooled_output)
686
- reshaped_logits = logits.view(-1, num_choices)
687
-
688
- loss = None
689
- if labels is not None:
690
- loss_fct = nn.CrossEntropyLoss()
691
- loss = loss_fct(reshaped_logits, labels)
692
-
693
- if not return_dict:
694
- output = (reshaped_logits,) + outputs[2:]
695
- return ((loss,) + output) if loss is not None else output
696
-
697
- return MultipleChoiceModelOutput(
698
- loss=loss,
699
- logits=reshaped_logits,
700
- hidden_states=None,
701
- attentions=None,
702
- )
703
-
704
-
705
- class BertForTokenClassification(BertPreTrainedModel):
706
- # TBD: Push in future commit
707
- pass
708
-
709
-
710
- class BertForQuestionAnswering(BertPreTrainedModel):
711
- """Bert Model with a span classification head.
712
-
713
- This is used for extractive question-answering tasks like SQuAD (a linear
714
- layers on top of the hidden states' output to compute `span start logits`
715
- and `span end logits`).
716
- """
717
-
718
- # TBD: Push in future commit
719
-
720
-
721
- ###################
722
- # FlexBert Heads
723
- ###################
724
-
725
-
726
- class FlexBertPredictionHead(nn.Module):
727
- def __init__(self, config: FlexBertConfig):
728
- super().__init__()
729
- self.config = config
730
- self.dense = nn.Linear(config.hidden_size, config.hidden_size, config.head_pred_bias)
731
- self.act = get_act_fn(config.head_pred_act) if config.head_pred_act else nn.Identity()
732
- self.norm = (
733
- get_norm_layer(config, compiled_norm=config.compile_model) if config.head_pred_norm else nn.Identity()
734
- )
735
-
736
- def _init_weights(self, reset_params: bool = False):
737
- if reset_params:
738
- self.norm.reset_parameters()
739
- init_weights(self.config, self.dense, layer_dim=self.config.hidden_size, type_of_module=ModuleType.in_module)
740
-
741
- def reset_parameters(self):
742
- self._init_weights(reset_params=True)
743
-
744
- def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
745
- return self.norm(self.act(self.dense(hidden_states)))
746
-
747
-
748
- class FlexBertPoolingHead(nn.Module):
749
- def __init__(self, config: FlexBertConfig):
750
- super().__init__()
751
- self.config = config
752
- self.dense = nn.Linear(config.hidden_size, config.hidden_size, config.head_class_bias)
753
- self.act = get_act_fn(config.head_class_act) if config.head_class_act else nn.Identity()
754
- self.norm = get_norm_layer(config) if config.head_class_norm else nn.Identity()
755
- self.drop = torch.nn.Dropout(config.head_class_dropout) if config.head_class_dropout > 0 else nn.Identity()
756
- self.pooling_type = config.pooling_type
757
-
758
- def forward(self, hidden_states: torch.Tensor, pool: Optional[bool] = True) -> torch.Tensor:
759
- if pool:
760
- if self.pooling_type == "cls":
761
- output = hidden_states[:, 0]
762
- elif self.pooling_type == "mean":
763
- output = hidden_states.mean(dim=1)
764
- elif self.pooling_type == "max":
765
- output = hidden_states.max(dim=1)[0]
766
- else:
767
- output = hidden_states
768
-
769
- return self.drop(self.norm(self.act(self.dense(output))))
770
-
771
- def _init_weights(self, reset_params: bool = False):
772
- init_weights(self.config, self.dense, self.config.hidden_size, type_of_module=ModuleType.out_module)
773
- if reset_params and hasattr(self.norm, "reset_parameters"):
774
- self.norm.reset_parameters()
775
-
776
- def reset_parameters(self):
777
- self._init_weights(reset_params=True)
778
-
779
-
780
- ###################
781
- # FlexBert Models
782
- ###################
783
-
784
-
785
- @dataclass
786
- class MaskedLMOutput(ModelOutput):
787
- """
788
- Base class for masked language models outputs.
789
-
790
- Args:
791
- loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
792
- Masked language modeling (MLM) loss.
793
- logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
794
- Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
795
- hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
796
- Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
797
- one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
798
-
799
- Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
800
- attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
801
- Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
802
- sequence_length)`.
803
-
804
- Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
805
- heads.
806
- """
807
-
808
- loss: Optional[torch.FloatTensor] = None
809
- logits: torch.FloatTensor = None
810
- hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
811
- attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
812
- indices: Optional[torch.LongTensor] = None
813
- cu_seqlens: Optional[torch.LongTensor] = None
814
- max_seqlen: Optional[int] = None
815
- batch_size: Optional[int] = None
816
- seq_len: Optional[int] = None
817
- labels: Optional[torch.LongTensor] = None
818
-
819
-
820
- @dataclass
821
- class MaskedLMOutputZLoss(ModelOutput):
822
- """
823
- Base class for masked language models outputs.
824
-
825
- Args:
826
- loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
827
- Masked language modeling (MLM) loss.
828
- ce_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
829
- Cross entropy loss.
830
- z_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
831
- Z loss.
832
- logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
833
- Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
834
- hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
835
- Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
836
- one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
837
-
838
- Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
839
- attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
840
- Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
841
- sequence_length)`.
842
-
843
- Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
844
- heads.
845
- indices (`torch.LongTensor` of shape `(batch_size,)`):
846
- Indices of the tokens to be masked.
847
- """
848
-
849
- loss: Optional[torch.FloatTensor] = None
850
- ce_loss: Optional[torch.FloatTensor] = None
851
- z_loss: Optional[torch.FloatTensor] = None
852
- logits: torch.FloatTensor = None
853
- hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
854
- attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
855
- indices: Optional[torch.LongTensor] = None
856
- cu_seqlens: Optional[torch.LongTensor] = None
857
- max_seqlen: Optional[int] = None
858
- batch_size: Optional[int] = None
859
- seq_len: Optional[int] = None
860
- labels: Optional[torch.LongTensor] = None
861
-
862
-
863
- class FlexBertPreTrainedModel(BertPreTrainedModel):
864
- """
865
- An abstract class to handle custom weights initialization of modules
866
- """
867
-
868
- def _init_module_weights(self, module: nn.Module):
869
- """
870
- Custom weight init of modules using src.bert_layers.initialization.init_weights
871
- Currently only supports init of embedding modules
872
- """
873
- assert isinstance(module, nn.Module)
874
- if isinstance(module, nn.Embedding):
875
- init_weights(self.config, module, type_of_module=ModuleType.emb)
876
- else:
877
- raise NotImplementedError("Custom weight init for the given module is not supported")
878
-
879
-
880
- class FlexBertModel(FlexBertPreTrainedModel):
881
- """Overall BERT model.
882
-
883
- Args:
884
- config: a BertConfig class instance with the configuration to build a new model
885
-
886
- Inputs:
887
- `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
888
- with the word token indices in the vocabulary (see the token preprocessing logic in the scripts
889
- `extract_features.py`, `run_classifier.py` and `run_squad.py`)
890
- `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
891
- types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
892
- a `sentence B` token (see BERT paper for more details).
893
- `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
894
- selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
895
- input sequence length in the current batch. It's the mask that we typically use for attention when
896
- a batch has varying length sentences.
897
- `output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output as described below. Default: `True`.
898
-
899
- Outputs: Tuple of (encoded_layers, pooled_output)
900
- `encoded_layers`: controlled by `output_all_encoded_layers` argument:
901
- - `output_all_encoded_layers=True`: outputs a list of the full sequences of encoded-hidden-states at the end
902
- of each attention block (i.e. 12 full sequences for BERT-base, 24 for BERT-large), each
903
- encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size],
904
- - `output_all_encoded_layers=False`: outputs only the full sequence of hidden-states corresponding
905
- to the last attention block of shape [batch_size, sequence_length, hidden_size],
906
- `pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a
907
- classifier pretrained on top of the hidden state associated with the first token of the
908
- input (`CLS`) to train on the Next-Sentence task (see BERT's paper).
909
-
910
- Example usage:
911
- ```python
912
- # Already been converted into WordPiece token ids
913
- input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
914
- input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
915
- token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
916
- config = modeling.BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
917
- num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
918
- model = BertModel(config=config)
919
- all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
920
- ```
921
- """
922
-
923
- def __init__(self, config: FlexBertConfig):
924
- super().__init__(config)
925
- self.embeddings = get_embedding_layer(config)
926
- self.encoder = get_encoder_layer(config)
927
- if config.final_norm:
928
- # if we use prenorm attention we need to add a final norm
929
- self.final_norm = get_norm_layer(config)
930
- else:
931
- self.final_norm = None
932
- self.unpad_embeddings = config.unpad_embeddings
933
-
934
- def post_init(self):
935
- self._init_weights(reset_params=False)
936
- self._backward_compatibility_gradient_checkpointing()
937
-
938
- def get_input_embeddings(self):
939
- return self.embeddings.tok_embeddings
940
-
941
- def set_input_embeddings(self, value):
942
- self.embeddings.tok_embeddings = value
943
-
944
- def forward(
945
- self,
946
- input_ids: torch.Tensor,
947
- attention_mask: Optional[torch.Tensor] = None,
948
- position_ids: Optional[torch.Tensor] = None,
949
- indices: Optional[torch.Tensor] = None,
950
- cu_seqlens: Optional[torch.Tensor] = None,
951
- max_seqlen: Optional[int] = None,
952
- **kwargs,
953
- ) -> Tuple[Union[List[torch.Tensor], torch.Tensor], Optional[torch.Tensor]]:
954
- if attention_mask is None:
955
- attention_mask = torch.ones_like(input_ids)
956
-
957
- embedding_output = self.embeddings(input_ids, position_ids)
958
-
959
- encoder_outputs = self.encoder(
960
- hidden_states=embedding_output,
961
- attention_mask=attention_mask,
962
- indices=indices,
963
- cu_seqlens=cu_seqlens,
964
- max_seqlen=max_seqlen,
965
- )
966
-
967
- if self.final_norm is not None:
968
- encoder_outputs = self.final_norm(encoder_outputs)
969
- return encoder_outputs
970
-
971
- def _init_weights(self, module: Optional[nn.Module] = None, reset_params: Optional[bool] = None):
972
- assert (module is None) != (reset_params is None), "arg module xor reset_params must be specified"
973
- if module:
974
- self._init_module_weights(module)
975
- else:
976
- assert isinstance(reset_params, bool)
977
- self.embeddings._init_weights(reset_params=reset_params)
978
- self.encoder._init_weights(reset_params=reset_params)
979
-
980
- if reset_params and self.config.final_norm:
981
- self.final_norm.reset_parameters()
982
-
983
- def reset_parameters(self):
984
- self._init_weights(reset_params=True)
985
-
986
- def get_number_parameters(self, count_embeddings: bool = True, trainable: bool = True) -> int:
987
- """Returns the number of parameters in the model.
988
-
989
- Args:
990
- count_embeddings: count the parameters in the embeddings layer, excluding position embeddings.
991
- trainable: only count trainable parameters.
992
- """
993
- params = sum([_count_parameters(layer, trainable) for layer in self.encoder.layers])
994
- if count_embeddings:
995
- params += _count_parameters(self.embeddings, trainable)
996
- if hasattr(self.embeddings, "position_embeddings"):
997
- params -= _count_parameters(self.embeddings.position_embeddings, trainable)
998
- return params
999
-
1000
-
1001
- class FlexBertForMaskedLM(FlexBertPreTrainedModel):
1002
- def __init__(self, config: FlexBertConfig):
1003
- super().__init__(config)
1004
- self.bert = FlexBertModel(config)
1005
- self.head = FlexBertPredictionHead(config)
1006
-
1007
- if config.tie_word_embeddings:
1008
- decoder_weights = self.bert.embeddings.tok_embeddings.weight
1009
- else:
1010
- decoder_weights = nn.Linear(config.hidden_size, config.vocab_size, bias=False).weight
1011
- self.decoder = nn.Linear(decoder_weights.size(1), decoder_weights.size(0), bias=config.decoder_bias)
1012
- self.decoder.weight = decoder_weights
1013
-
1014
- self.loss_fn = nn.CrossEntropyLoss() if not hasattr(config, "loss_function") else get_loss_fn(config)
1015
- self.fa_ce = getattr(config, "loss_function", "cross_entropy") == "fa_cross_entropy"
1016
- self.return_z_loss = config.loss_kwargs.get("return_z_loss", False)
1017
- self.unpad_embeddings = config.unpad_embeddings
1018
- self.pad_logits = config.pad_logits
1019
- self.compile_model = config.compile_model
1020
- self.masked_prediction = config.masked_prediction
1021
-
1022
- # Initialize weights and apply final processing
1023
- self._init_weights(reset_params=False)
1024
-
1025
- def _init_weights(self, module: Optional[nn.Module] = None, reset_params: Optional[bool] = None):
1026
- assert (module is None) != (reset_params is None), "arg module xor reset_params must be specified"
1027
- if module:
1028
- self._init_module_weights(module)
1029
- else:
1030
- assert isinstance(reset_params, bool)
1031
- self.bert._init_weights(reset_params=reset_params)
1032
- self.head._init_weights(reset_params=reset_params)
1033
-
1034
- # Output weights.
1035
- if not self.config.tie_word_embeddings:
1036
- init_weights(self.config, self.decoder, self.config.hidden_size, type_of_module=ModuleType.final_out)
1037
-
1038
- @classmethod
1039
- def from_composer(
1040
- cls,
1041
- pretrained_checkpoint,
1042
- state_dict=None,
1043
- cache_dir=None,
1044
- from_tf=False,
1045
- config=None,
1046
- *inputs,
1047
- **kwargs,
1048
- ):
1049
- """Load from pre-trained."""
1050
- model = cls(config, *inputs, **kwargs)
1051
- if from_tf:
1052
- raise ValueError("FlexBERT does not support loading TensorFlow weights.")
1053
-
1054
- state_dict = torch.load(pretrained_checkpoint)
1055
- # If the state_dict was saved after wrapping with `composer.HuggingFaceModel`, it takes on the `model` prefix
1056
- consume_prefix_in_state_dict_if_present(state_dict, prefix="model.")
1057
- missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
1058
-
1059
- if len(missing_keys) > 0:
1060
- logger.warning(f"Found these missing keys in the checkpoint: {', '.join(missing_keys)}")
1061
- if len(unexpected_keys) > 0:
1062
- logger.warning(f"Found these unexpected keys in the checkpoint: {', '.join(unexpected_keys)}")
1063
-
1064
- return model
1065
-
1066
- def get_output_embeddings(self):
1067
- return self.decoder
1068
-
1069
- def set_output_embeddings(self, new_embeddings):
1070
- self.decoder = new_embeddings
1071
-
1072
- @torch.no_grad()
1073
- def unpad_inputs(
1074
- self, input_ids: torch.Tensor, attention_mask: torch.Tensor, position_ids: torch.Tensor, labels: torch.Tensor
1075
- ):
1076
- return unpad_input(input_ids, attention_mask, position_ids, labels)
1077
-
1078
- @torch.no_grad()
1079
- def pad_inputs(
1080
- self,
1081
- inputs: torch.Tensor,
1082
- indices: torch.Tensor,
1083
- batch_size: int,
1084
- seqlen: int,
1085
- labels: Optional[torch.Tensor] = None,
1086
- ignore_index: int = -100,
1087
- ):
1088
- return pad_input(
1089
- inputs=inputs, indices=indices, batch=batch_size, seqlen=seqlen, labels=labels, ignore_index=ignore_index
1090
- )
1091
-
1092
- @torch.compile(dynamic=True)
1093
- def compiled_head(self, output: torch.Tensor) -> torch.Tensor:
1094
- return self.decoder(self.head(output))
1095
-
1096
- def forward(
1097
- self,
1098
- input_ids: Optional[torch.Tensor],
1099
- attention_mask: Optional[torch.Tensor] = None,
1100
- position_ids: Optional[torch.Tensor] = None,
1101
- labels: Optional[torch.Tensor] = None,
1102
- return_dict: Optional[bool] = None,
1103
- indices: Optional[torch.Tensor] = None,
1104
- cu_seqlens: Optional[torch.Tensor] = None,
1105
- max_seqlen: Optional[int] = None,
1106
- batch_size: Optional[int] = None,
1107
- seq_len: Optional[int] = None,
1108
- **kwargs,
1109
- ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
1110
- # labels should be a `torch.LongTensor` of shape
1111
- # `(batch_size, sequence_length)`. These are used for computing the
1112
- # masked language modeling loss.
1113
- #
1114
- # Indices should be in `[-100, 0, ..., config.vocab_size]` (see
1115
- # `input_ids` docstring) Tokens with indices set to `-100` are ignored
1116
- # (masked), the loss is only computed for the tokens with labels in `[0,
1117
- # ..., config.vocab_size]`
1118
- #
1119
- # Prediction scores are only computed for masked tokens and the (bs,
1120
- # seqlen) dimensions are flattened
1121
-
1122
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1123
-
1124
- if self.unpad_embeddings and (indices is None and cu_seqlens is None and max_seqlen is None):
1125
- batch_size, seq_len = input_ids.shape[:2]
1126
- input_ids, indices, cu_seqlens, max_seqlen, position_ids, labels = self.unpad_inputs(
1127
- input_ids, attention_mask, position_ids, labels
1128
- )
1129
-
1130
- output = self.bert(
1131
- input_ids,
1132
- attention_mask=attention_mask,
1133
- position_ids=position_ids,
1134
- indices=indices,
1135
- cu_seqlens=cu_seqlens,
1136
- max_seqlen=max_seqlen,
1137
- )
1138
-
1139
- if self.masked_prediction and labels is not None:
1140
- # flatten labels and output first
1141
- labels = labels.view(-1)
1142
- output = output.view(labels.shape[0], -1)
1143
-
1144
- # then filter out the non-masked tokens
1145
- mask_tokens = labels != self.loss_fn.ignore_index
1146
- output = output[mask_tokens]
1147
- labels = labels[mask_tokens]
1148
-
1149
- if self.compile_model:
1150
- logits = self.compiled_head(output)
1151
- else:
1152
- logits = self.decoder(self.head(output))
1153
-
1154
- loss = None
1155
- if labels is not None:
1156
- if not self.masked_prediction:
1157
- labels = labels.view(-1)
1158
- logits = logits.view(labels.shape[0], -1)
1159
-
1160
- if self.return_z_loss:
1161
- loss, z_loss = self.loss_fn(logits, labels)
1162
- if self.pad_logits:
1163
- return MaskedLMOutputZLoss(
1164
- loss=loss,
1165
- ce_loss=loss.detach().clone() - z_loss,
1166
- z_loss=z_loss,
1167
- logits=self.pad_inputs(logits, indices, batch_size, seq_len)[0],
1168
- hidden_states=None,
1169
- attentions=None,
1170
- )
1171
- else:
1172
- return MaskedLMOutputZLoss(
1173
- loss=loss,
1174
- ce_loss=loss.detach().clone() - z_loss,
1175
- z_loss=z_loss,
1176
- logits=logits,
1177
- hidden_states=None,
1178
- attentions=None,
1179
- indices=indices,
1180
- cu_seqlens=cu_seqlens,
1181
- max_seqlen=max_seqlen,
1182
- batch_size=batch_size,
1183
- seq_len=seq_len,
1184
- labels=labels,
1185
- )
1186
- else:
1187
- loss = self.loss_fn(logits, labels)
1188
-
1189
- if self.pad_logits:
1190
- return MaskedLMOutput(
1191
- loss=loss,
1192
- logits=self.pad_inputs(logits, indices, batch_size, seq_len)[0],
1193
- hidden_states=None,
1194
- attentions=None,
1195
- )
1196
- else:
1197
- return MaskedLMOutput(
1198
- loss=loss,
1199
- logits=logits,
1200
- hidden_states=None,
1201
- attentions=None,
1202
- indices=indices,
1203
- cu_seqlens=cu_seqlens,
1204
- max_seqlen=max_seqlen,
1205
- batch_size=batch_size,
1206
- seq_len=seq_len,
1207
- labels=labels,
1208
- )
1209
-
1210
- def prepare_inputs_for_generation(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **model_kwargs):
1211
- input_shape = input_ids.shape
1212
- effective_batch_size = input_shape[0]
1213
-
1214
- # add a dummy token
1215
- if self.config.pad_token_id is None:
1216
- raise ValueError("The PAD token should be defined for generation")
1217
-
1218
- attention_mask = torch.cat(
1219
- [attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))],
1220
- dim=-1,
1221
- )
1222
- dummy_token = torch.full(
1223
- (effective_batch_size, 1),
1224
- self.config.pad_token_id,
1225
- dtype=torch.long,
1226
- device=input_ids.device,
1227
- )
1228
- input_ids = torch.cat([input_ids, dummy_token], dim=1)
1229
-
1230
- return {"input_ids": input_ids, "attention_mask": attention_mask}
1231
-
1232
- def get_number_parameters(
1233
- self, count_embeddings: bool = True, count_decoder: bool = False, trainable: bool = True
1234
- ) -> int:
1235
- """Returns the number of parameters in the model.
1236
-
1237
- Args:
1238
- count_embeddings: count the parameters in the embeddings layer, excluding position embeddings.
1239
- count_decoder: count the parameters in the decoder layer if weights are not tied.
1240
- trainable: only count trainable parameters.
1241
- """
1242
- params = self.bert.get_number_parameters(count_embeddings, trainable)
1243
- params += _count_parameters(self.head, trainable)
1244
- if count_decoder and not self.config.tie_word_embeddings:
1245
- params += _count_parameters(self.decoder, trainable)
1246
- return params
1247
-
1248
-
1249
- class FlexBertForSequenceClassification(FlexBertPreTrainedModel):
1250
- """Bert Model transformer with a sequence classification/regression head.
1251
-
1252
- This head is just a linear layer on top of the pooled output. Used for,
1253
- e.g., GLUE tasks.
1254
- """
1255
-
1256
- def __init__(self, config: FlexBertConfig):
1257
- super().__init__(config)
1258
- self.num_labels = config.num_labels
1259
- self.config = config
1260
-
1261
- self.bert = FlexBertModel(config)
1262
- self.head = FlexBertPoolingHead(config)
1263
- self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1264
-
1265
- # Initialize weights and apply final processing
1266
- self._init_weights(reset_params=False)
1267
-
1268
- def _init_weights(self, module: Optional[nn.Module] = None, reset_params: Optional[bool] = None):
1269
- assert (module is None) != (reset_params is None), "arg module xor reset_params must be specified"
1270
- if module:
1271
- self._init_module_weights(module)
1272
- else:
1273
- assert isinstance(reset_params, bool)
1274
- self.bert._init_weights(reset_params=reset_params)
1275
- self.head._init_weights(reset_params=reset_params)
1276
- init_weights(self.config, self.classifier, self.config.hidden_size, type_of_module=ModuleType.final_out)
1277
-
1278
- @classmethod
1279
- def from_composer(
1280
- cls,
1281
- pretrained_checkpoint,
1282
- state_dict=None,
1283
- cache_dir=None,
1284
- from_tf=False,
1285
- config=None,
1286
- *inputs,
1287
- **kwargs,
1288
- ):
1289
- """Load from pre-trained."""
1290
- model = cls(config, *inputs, **kwargs)
1291
- if from_tf:
1292
- raise ValueError("Mosaic BERT does not support loading TensorFlow weights.")
1293
-
1294
- state_dict = torch.load(pretrained_checkpoint)
1295
- # If the state_dict was saved after wrapping with `composer.HuggingFaceModel`, it takes on the `model` prefix
1296
- consume_prefix_in_state_dict_if_present(state_dict, prefix="model.")
1297
- missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
1298
-
1299
- if len(missing_keys) > 0:
1300
- logger.warning(f"Found these missing keys in the checkpoint: {', '.join(missing_keys)}")
1301
- if len(unexpected_keys) > 0:
1302
- logger.warning(f"Found these unexpected keys in the checkpoint: {', '.join(unexpected_keys)}")
1303
-
1304
- return model
1305
-
1306
- def forward(
1307
- self,
1308
- input_ids: Optional[torch.Tensor] = None,
1309
- attention_mask: Optional[torch.Tensor] = None,
1310
- position_ids: Optional[torch.Tensor] = None,
1311
- labels: Optional[torch.Tensor] = None,
1312
- return_dict: Optional[bool] = None,
1313
- ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
1314
- # labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1315
- # Labels for computing the sequence classification/regression loss.
1316
- # Indices should be in `[0, ..., config.num_labels - 1]`.
1317
- # If `config.num_labels == 1` a regression loss is computed
1318
- # (mean-square loss). If `config.num_labels > 1` a classification loss
1319
- # is computed (cross-entropy).
1320
-
1321
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1322
-
1323
- output = self.bert(
1324
- input_ids,
1325
- attention_mask=attention_mask,
1326
- position_ids=position_ids,
1327
- )
1328
-
1329
- pooled_output = self.head(output)
1330
- logits = self.classifier(pooled_output)
1331
-
1332
- loss = None
1333
- if labels is not None:
1334
- # Compute loss
1335
- if self.config.problem_type is None:
1336
- if self.num_labels == 1:
1337
- self.config.problem_type = "regression"
1338
- elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1339
- self.config.problem_type = "single_label_classification"
1340
- else:
1341
- self.config.problem_type = "multi_label_classification"
1342
-
1343
- if self.config.problem_type == "regression":
1344
- loss_fct = nn.MSELoss()
1345
- if self.num_labels == 1:
1346
- loss = loss_fct(logits.squeeze(), labels.squeeze())
1347
- else:
1348
- loss = loss_fct(logits, labels)
1349
- elif self.config.problem_type == "single_label_classification":
1350
- loss_fct = nn.CrossEntropyLoss()
1351
- loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1352
- elif self.config.problem_type == "multi_label_classification":
1353
- loss_fct = nn.BCEWithLogitsLoss()
1354
- loss = loss_fct(logits, labels)
1355
-
1356
- if not return_dict:
1357
- output = (logits,) + output
1358
- return ((loss,) + output) if loss is not None else output
1359
-
1360
- return SequenceClassifierOutput(
1361
- loss=loss,
1362
- logits=logits,
1363
- hidden_states=None,
1364
- attentions=None,
1365
- )
1366
-
1367
- def get_number_parameters(self, count_embeddings: bool = True, trainable: bool = True) -> int:
1368
- """Returns the number of parameters in the model.
1369
-
1370
- Args:
1371
- count_embeddings: count the parameters in the embeddings layer, excluding position embeddings.
1372
- trainable: only count trainable parameters.
1373
- """
1374
- params = self.bert.get_number_parameters(count_embeddings, trainable)
1375
- params += _count_parameters(self.head, trainable)
1376
- params += _count_parameters(self.classifier, trainable)
1377
- return params
1378
-
1379
-
-class FlexBertForMultipleChoice(FlexBertPreTrainedModel):
-    """
-    Bert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
-    softmax) e.g. for RocStories/SWAG tasks.
-    """
-
-    def __init__(self, config: FlexBertConfig):
-        super().__init__(config)
-        self.num_labels = config.num_labels
-        self.config = config
-
-        self.bert = FlexBertModel(config)
-        self.head = FlexBertPoolingHead(config)
-
-        # In multiple choice tasks, all choices are submitted in a batch, and
-        # we compute a logit for each option independently. The logits are then
-        # normalized in the forward pass to get a probability distribution over
-        # the choices.
-        self.classifier = nn.Linear(config.hidden_size, 1)
-
-        # Initialize weights and apply final processing
-        self._init_weights(reset_params=False)
-
-    def _init_weights(self, module: Optional[nn.Module] = None, reset_params: Optional[bool] = None):
-        assert (module is None) != (reset_params is None), "arg module xor reset_params must be specified"
-        if module:
-            self._init_module_weights(module)
-        else:
-            assert isinstance(reset_params, bool)
-            self.bert._init_weights(reset_params=reset_params)
-            self.head._init_weights(reset_params=reset_params)
-            init_weights(self.config, self.classifier, self.config.hidden_size, type_of_module=ModuleType.final_out)
-
-    @classmethod
-    def from_composer(
-        cls,
-        pretrained_checkpoint,
-        state_dict=None,
-        cache_dir=None,
-        from_tf=False,
-        config=None,
-        *inputs,
-        **kwargs,
-    ):
-        """Load from pre-trained."""
-        model = cls(config, *inputs, **kwargs)
-        if from_tf:
-            raise ValueError("Mosaic BERT does not support loading TensorFlow weights.")
-
-        state_dict = torch.load(pretrained_checkpoint)
-        # If the state_dict was saved after wrapping with `composer.HuggingFaceModel`, it takes on the `model` prefix
-        consume_prefix_in_state_dict_if_present(state_dict, prefix="model.")
-        missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
-
-        if len(missing_keys) > 0:
-            logger.warning(f"Found these missing keys in the checkpoint: {', '.join(missing_keys)}")
-        if len(unexpected_keys) > 0:
-            logger.warning(f"Found these unexpected keys in the checkpoint: {', '.join(unexpected_keys)}")
-
-        return model
-
-    def forward(
-        self,
-        input_ids: Optional[torch.Tensor] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.Tensor] = None,
-        labels: Optional[torch.Tensor] = None,
-        return_dict: Optional[bool] = None,
-    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
-        # labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
-        # Labels for computing the sequence classification/regression loss.
-        # Indices should be in `[0, ..., config.num_labels - 1]`.
-        # If `config.num_labels == 1` a regression loss is computed
-        # (mean-square loss). If `config.num_labels > 1` a classification loss
-        # is computed (cross-entropy).
-
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-        num_choices = input_ids.shape[1]
-
-        input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
-        attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
-        position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
-
-        output = self.bert(
-            input_ids,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-        )
-
-        pooled_output = self.head(output)
-        logits = self.classifier(pooled_output)
-        reshaped_logits = logits.view(-1, num_choices)
-
-        loss = None
-        if labels is not None:
-            loss_fct = nn.CrossEntropyLoss()
-            loss = loss_fct(reshaped_logits, labels)
-
-        if not return_dict:
-            output = (reshaped_logits,) + output
-            return ((loss,) + output) if loss is not None else output
-
-        return MultipleChoiceModelOutput(
-            loss=loss,
-            logits=reshaped_logits,
-            hidden_states=None,
-            attentions=None,
-        )
-
-    def get_number_parameters(self, count_embeddings: bool = True, trainable: bool = True) -> int:
-        """Returns the number of parameters in the model.
-
-        Args:
-            count_embeddings: count the parameters in the embeddings layer, excluding position embeddings.
-            trainable: only count trainable parameters.
-        """
-        params = self.bert.get_number_parameters(count_embeddings, trainable)
-        params += _count_parameters(self.head, trainable)
-        params += _count_parameters(self.classifier, trainable)
-        return params
-
-
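The shape bookkeeping in the multiple-choice head above is easier to follow with concrete tensors. The sketch below uses random stand-ins for the encoder and pooling output (in the real model, `pooled` comes from `self.head(self.bert(...))`); the sizes are arbitrary examples.

import torch
import torch.nn as nn

batch_size, num_choices, seq_len, hidden = 2, 4, 16, 8

# Choices arrive stacked per example and are flattened before the encoder call.
input_ids = torch.randint(0, 100, (batch_size, num_choices, seq_len))
flat_input_ids = input_ids.view(-1, input_ids.size(-1))     # (8, 16)

# One logit per (example, choice), reshaped back to (batch, num_choices) for cross-entropy.
pooled = torch.randn(batch_size * num_choices, hidden)      # stand-in for the pooled encoder output
classifier = nn.Linear(hidden, 1)
reshaped_logits = classifier(pooled).view(-1, num_choices)  # (2, 4)

labels = torch.tensor([1, 3])                               # index of the correct choice per example
loss = nn.CrossEntropyLoss()(reshaped_logits, labels)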
-def init_model_from_pretrained(
-    pretrained_model: FlexBertModel,
-    new_model: FlexBertModel,
-    mode: Union[str, TileMode] = TileMode.tile_weights_from_middle,
-):
-    """
-    Initialize the new model from the pretrained model.
-
-    This method uses Gopher layer scaling and Phi-style weight tiling as selected by `mode`.
-    The new model must have the same or more layers and the same or larger dimensions than the pretrained model.
-
-    Args:
-        pretrained_model (FlexBertModel): The smaller, pre-trained model
-        new_model (FlexBertModel): The larger model to be initialized
-        mode (Union[str, TileMode]): The Phi-style weight tiling mode to use
-
-    This function assumes that the new_model has more layers and a larger hidden size
-    than the pretrained_model, but the same vocabulary size.
-    """
-
-    # Tile embeddings
-    assert isinstance(
-        new_model.embeddings, type(pretrained_model.embeddings)
-    ), f"Pretrained and new_model layers must be the same type, got {type(new_model.embeddings)} and {type(pretrained_model.embeddings)}"
-    assert isinstance(
-        new_model.embeddings,
-        (FlexBertAbsoluteEmbeddings, FlexBertSansPositionEmbeddings, FlexBertCompiledSansPositionEmbeddings),
-    ), f"Unsupported embedding layer type: {type(new_model.embeddings)}"
-
-    tile_embedding(pretrained_model.embeddings.tok_embeddings, new_model.embeddings.tok_embeddings, mode=mode)
-    if isinstance(pretrained_model.embeddings, FlexBertAbsoluteEmbeddings):
-        tile_embedding(pretrained_model.embeddings.pos_embeddings, new_model.embeddings.pos_embeddings, mode=mode)
-
-    if hasattr(pretrained_model.embeddings, "norm"):
-        tile_norm(pretrained_model.embeddings.norm, new_model.embeddings.norm, mode=mode)
-
-    # Tile encoder layers
-    assert isinstance(
-        pretrained_model.encoder, (FlexBertUnpadEncoder, FlexBertPaddedEncoder)
-    ), f"Unsupported encoder layer type: {type(pretrained_model.encoder)}"
-    assert isinstance(
-        new_model.encoder, type(pretrained_model.encoder)
-    ), f"Pretrained and new_model encoder layers must be the same type, got {type(new_model.encoder)} and {type(pretrained_model.encoder)}"
-
-    # Calculate the layer mapping
-    pretrained_layers = len(pretrained_model.encoder.layers)
-    new_layers = len(new_model.encoder.layers)
-    layer_mapping = [round(i * pretrained_layers / new_layers) for i in range(new_layers)]
-
-    # Initialize layers
-    for new_model_idx, pretrained_idx in enumerate(layer_mapping):
-        new_model_layer = new_model.encoder.layers[new_model_idx]
-        pretrained_layer = pretrained_model.encoder.layers[pretrained_idx]
-
-        # first tile the PreNorm/PostNorm layers
-        assert isinstance(
-            new_model_layer, type(pretrained_layer)
-        ), f"Pretrained and new_model prenorm/postnorm layers must be the same type, got {type(new_model_layer)} and {type(pretrained_layer)}"
-        assert isinstance(
-            new_model_layer,
-            (
-                FlexBertUnpadPreNormLayer,
-                FlexBertCompileUnpadPreNormLayer,
-                FlexBertUnpadParallelPreNormLayer,
-                FlexBertUnpadPostNormLayer,
-                FlexBertPaddedPreNormLayer,
-                FlexBertPaddedParallelPreNormLayer,
-                FlexBertPaddedPostNormLayer,
-            ),
-        ), f"Unsupported prenorm/postnorm layer type: {type(new_model_layer)}"
-
-        # First tile the normalization layers
-        if hasattr(pretrained_layer, "attn_norm"):
-            tile_norm(pretrained_layer.attn_norm, new_model_layer.attn_norm, mode=mode)
-        if hasattr(pretrained_layer, "norm"):
-            tile_norm(pretrained_layer.norm, new_model_layer.norm, mode=mode)
-        if hasattr(pretrained_layer, "mlp_norm"):
-            tile_norm(pretrained_layer.mlp_norm, new_model_layer.mlp_norm, mode=mode)
-
-        # Then tile the attention & mlp layers
-        assert isinstance(
-            new_model_layer.attn, type(pretrained_layer.attn)
-        ), f"Pretrained and new_model attention layers must be the same type, got {type(new_model_layer.attn)} and {type(pretrained_layer.attn)}"
-
-        # first try the parallel attention layers
-        if isinstance(pretrained_layer, (FlexBertUnpadParallelPreNormLayer, FlexBertPaddedParallelPreNormLayer)):
-            assert isinstance(
-                pretrained_layer.attn,
-                (
-                    FlexBertUnpadParallelAttention,
-                    FlexBertPaddedParallelAttention,
-                    FlexBertUnpadRopeParallelAttention,
-                    FlexBertPaddedRopeParallelAttention,
-                ),
-            ), f"Parallel prenorm layer must have parallel attention layer: {type(pretrained_layer.attn)}"
-            if not isinstance(pretrained_layer.mlp, (FlexBertParallelGLU)):
-                raise ValueError(f"Parallel prenorm layer must have parallel MLP layer: {type(pretrained_layer.mlp)}")
-            tile_linear(
-                pretrained_layer.Wqkvff,
-                new_model_layer.Wqkvff,
-                linear_type=TileLinear.wqkvff,
-                mode=mode,
-                pretrained_attn_size=pretrained_layer.attn_size,
-                pretrained_mlp_size=pretrained_layer.mlp_size,
-                new_attn_size=new_model_layer.attn_size,
-                new_mlp_size=new_model_layer.mlp_size,
-                wqkvff_is_glu=True,
-            )
-
-        # then try the fused attention layers
-        elif isinstance(
-            pretrained_layer.attn,
-            (
-                FlexBertUnpadAttention,
-                FlexBertPaddedAttention,
-                FlexBertUnpadRopeAttention,
-                FlexBertPaddedRopeAttention,
-            ),
-        ):
-            tile_linear(pretrained_layer.attn.Wqkv, new_model_layer.attn.Wqkv, linear_type=TileLinear.wqkv, mode=mode)
-        else:
-            raise ValueError(f"Unsupported attention layer type: {type(pretrained_layer.attn)}")
-
-        # finally, tile the attention output layer
-        tile_linear(pretrained_layer.attn.Wo, new_model_layer.attn.Wo, linear_type=TileLinear.default, mode=mode)
-
-        # tile the mlp layer if the model is not using parallel attention layers
-        if not isinstance(pretrained_layer.mlp, (FlexBertMLP, FlexBertGLU, FlexBertParallelGLU)):
-            raise ValueError(f"Unsupported MLP layer type: {type(pretrained_layer.mlp)}")
-        assert isinstance(
-            new_model_layer.mlp, type(pretrained_layer.mlp)
-        ), f"Pretrained and new_model mlp layers must be the same type, got {type(new_model_layer.mlp)} and {type(pretrained_layer.mlp)}"
-
-        # already tiled the parallel glu layer if it exists, so only need to handle mlp & glu Wi
-        if isinstance(pretrained_layer.mlp, FlexBertGLU):
-            tile_linear(pretrained_layer.mlp.Wi, new_model_layer.mlp.Wi, linear_type=TileLinear.glu, mode=mode)
-        elif isinstance(pretrained_layer.mlp, FlexBertMLP):
-            tile_linear(pretrained_layer.mlp.Wi, new_model_layer.mlp.Wi, linear_type=TileLinear.default, mode=mode)
-        # tile the output for both ParallelGLU and MLP/GLU
-        tile_linear(pretrained_layer.mlp.Wo, new_model_layer.mlp.Wo, linear_type=TileLinear.default, mode=mode)
-
-
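To make the depth-growing step concrete, this is the layer mapping the function above computes when tiling a hypothetical 12-layer encoder into a 22-layer one (the sizes are arbitrary examples): each new layer is initialized from the pretrained layer at roughly the same relative depth.

pretrained_layers, new_layers = 12, 22
layer_mapping = [round(i * pretrained_layers / new_layers) for i in range(new_layers)]
print(layer_mapping)
# [0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 7, 7, 8, 8, 9, 9, 10, 10, 11, 11]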
-def init_mlm_model_from_pretrained(
-    config: FlexBertConfig,
-    pretrained_model: FlexBertForMaskedLM,
-    new_model: FlexBertForMaskedLM,
-    mode: Union[str, TileMode] = TileMode.tile_weights_from_middle,
-):
-    """
-    Initialize the new model from the pretrained model.
-
-    This method uses Gopher layer scaling and Phi-style weight tiling as selected by `mode`.
-    The new model must have the same or more layers and the same or larger dimensions than the pretrained model.
-
-    Args:
-        config (FlexBertConfig): The configuration of the new_model
-        pretrained_model (FlexBertForMaskedLM): The smaller, pre-trained model
-        new_model (FlexBertForMaskedLM): The larger model to be initialized from the pretrained model
-        mode (Union[str, TileMode]): The Phi-style weight tiling mode to use
-
-    This function assumes that the new_model has more layers and a larger hidden size
-    than the pretrained_model, but the same vocabulary size.
-    """
-    init_model_from_pretrained(pretrained_model.bert, new_model.bert, mode=mode)
-
-    # TODO: uncomment this when the repo is turned into a pip installable package
-    # if not isinstance(pretrained_model.head, FlexBertPredictionHead):
-    #     raise ValueError(f"Pretrained model must have a prediction head: {type(pretrained_model.head)}")
-    # if not isinstance(new_model.head, FlexBertPredictionHead):
-    #     raise ValueError(f"New model must have a prediction head: {type(new_model.head)}")
-
-    # tile the prediction head
-    tile_linear(pretrained_model.head.dense, new_model.head.dense, linear_type=TileLinear.default, mode=mode)
-    tile_norm(pretrained_model.head.norm, new_model.head.norm, mode=mode)
-
-    # setup weight tying
-    if config.tie_word_embeddings:
-        new_model.decoder.weight = new_model.bert.embeddings.tok_embeddings.weight
-        tile_linear(
-            pretrained_model.decoder, new_model.decoder, linear_type=TileLinear.default, mode=mode, bias_only=True
-        )
-    else:
-        tile_linear(pretrained_model.decoder, new_model.decoder, linear_type=TileLinear.default, mode=mode)
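A minimal usage sketch for the two initialization helpers above, assuming two FlexBERT configs that differ only in width and depth. The config values are illustrative, no pretrained checkpoint is actually loaded here, and the classes are assumed to be importable from this module.

# Illustrative only: grow a small pretrained MLM model into a larger one.
small_cfg = FlexBertConfig(hidden_size=384, num_hidden_layers=12)
large_cfg = FlexBertConfig(hidden_size=768, num_hidden_layers=22)

small_model = FlexBertForMaskedLM(small_cfg)  # in practice, load the pretrained weights first
large_model = FlexBertForMaskedLM(large_cfg)

init_mlm_model_from_pretrained(
    config=large_cfg,
    pretrained_model=small_model,
    new_model=large_model,
    mode=TileMode.tile_weights_from_middle,
)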
modeling_flexbert.py CHANGED
@@ -69,8 +69,8 @@ from transformers.models.bert.modeling_bert import BertPreTrainedModel
 
 from bert_padding import index_put_first_axis
 
-from activation import get_act_fn
-from attention import (
+from src.bert_layers.activation import get_act_fn
+from src.bert_layers.attention import (
     FlexBertPaddedAttention,
     FlexBertPaddedParallelAttention,
     FlexBertPaddedRopeAttention,
@@ -80,15 +80,15 @@ from attention import (
     FlexBertUnpadRopeAttention,
     FlexBertUnpadRopeParallelAttention,
 )
-from configuration_bert import FlexBertConfig
-from embeddings import (
+from src.bert_layers.configuration_bert import FlexBertConfig
+from src.bert_layers.embeddings import (
     BertAlibiEmbeddings,
     FlexBertAbsoluteEmbeddings,
     FlexBertCompiledSansPositionEmbeddings,
     FlexBertSansPositionEmbeddings,
     get_embedding_layer,
 )
-from initialization import (
+from src.bert_layers.initialization import (
     ModuleType,
     TileLinear,
     TileMode,
@@ -97,7 +97,7 @@ from initialization import (
     tile_linear,
     tile_norm,
 )
-from layers import (
+from src.bert_layers.layers import (
     BertAlibiEncoder,
     BertPooler,
     BertPredictionHeadTransform,
@@ -112,9 +112,10 @@ from layers import (
     FlexBertUnpadPreNormLayer,
     get_encoder_layer,
 )
-from mlp import FlexBertGLU, FlexBertMLP, FlexBertParallelGLU
-from normalization import get_norm_layer
-from padding import pad_input, unpad_input
+from src.bert_layers.loss import get_loss_fn
+from src.bert_layers.mlp import FlexBertGLU, FlexBertMLP, FlexBertParallelGLU
+from src.bert_layers.normalization import get_norm_layer
+from src.bert_layers.padding import pad_input, unpad_input
 
 logger = logging.getLogger(__name__)
 
@@ -866,16 +867,14 @@ class FlexBertPreTrainedModel(BertPreTrainedModel):
 
     def _init_module_weights(self, module: nn.Module):
         """
-        Custom weight init of modules using initialization.init_weights
+        Custom weight init of modules using src.bert_layers.initialization.init_weights
         Currently only supports init of embedding modules
         """
         assert isinstance(module, nn.Module)
         if isinstance(module, nn.Embedding):
             init_weights(self.config, module, type_of_module=ModuleType.emb)
         else:
-            print(module)
-            print("Custom weight init for the given module is not supported, please fix")
-            # raise NotImplementedError("Custom weight init for the given module is not supported")
+            raise NotImplementedError("Custom weight init for the given module is not supported")
 
 
 class FlexBertModel(FlexBertPreTrainedModel):
@@ -968,7 +967,7 @@ class FlexBertModel(FlexBertPreTrainedModel):
         if self.final_norm is not None:
             encoder_outputs = self.final_norm(encoder_outputs)
         return encoder_outputs
-
+
     def _init_weights(self, module: Optional[nn.Module] = None, reset_params: Optional[bool] = None):
         assert (module is None) != (reset_params is None), "arg module xor reset_params must be specified"
         if module:
@@ -1012,6 +1011,7 @@ class FlexBertForMaskedLM(FlexBertPreTrainedModel):
         self.decoder = nn.Linear(decoder_weights.size(1), decoder_weights.size(0), bias=config.decoder_bias)
         self.decoder.weight = decoder_weights
 
+        self.loss_fn = nn.CrossEntropyLoss() if not hasattr(config, "loss_function") else get_loss_fn(config)
         self.fa_ce = getattr(config, "loss_function", "cross_entropy") == "fa_cross_entropy"
         self.return_z_loss = config.loss_kwargs.get("return_z_loss", False)
         self.unpad_embeddings = config.unpad_embeddings
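The net effect of the `FlexBertForMaskedLM.__init__` change above is that the training loss is now resolved from the configuration, via the new `src.bert_layers.loss.get_loss_fn` import, with a plain `nn.CrossEntropyLoss` fallback for configs that predate the `loss_function` field. A minimal sketch of how a config might opt into the flash-attn kernel follows; the field values are hypothetical examples, not defaults.

from src.bert_layers.configuration_bert import FlexBertConfig
from src.bert_layers.loss import get_loss_fn

# Hypothetical settings: the attribute names follow the fields read in the diff above.
config = FlexBertConfig(
    loss_function="fa_cross_entropy",  # "cross_entropy" is the fallback used by getattr in __init__
    loss_kwargs={"ignore_index": -100, "return_z_loss": True},
)

loss_fn = get_loss_fn(config)  # what FlexBertForMaskedLM.__init__ now stores as self.loss_fn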