# coding=utf-8 """PyTorch BERT model.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import os import copy import json import math import logging import tarfile import tempfile import shutil import numpy as np from scipy.stats import truncnorm import torch from torch import nn from torch.nn import CrossEntropyLoss, MSELoss import torch.nn.functional as F from .file_utils import cached_path from .loss import LabelSmoothingLoss logger = logging.getLogger(__name__) PRETRAINED_MODEL_ARCHIVE_MAP = { 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz", 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased.tar.gz", 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased.tar.gz", 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased.tar.gz", 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased.tar.gz", 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased.tar.gz", 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz", } CONFIG_NAME = 'bert_config.json' WEIGHTS_NAME = 'pytorch_model.bin' def gelu(x): """Implementation of the gelu activation function. For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) """ return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) def swish(x): return x * torch.sigmoid(x) ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish} class BertConfig(object): """Configuration class to store the configuration of a `BertModel`. """ def __init__(self, vocab_size_or_config_json_file, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=2, relax_projection=0, new_pos_ids=False, initializer_range=0.02, task_idx=None, fp32_embedding=False, ffn_type=0, label_smoothing=None, num_qkv=0, seg_emb=False): """Constructs BertConfig. Args: vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `BertModel`. hidden_size: Size of the encoder layers and the pooler layer. num_hidden_layers: Number of hidden layers in the Transformer encoder. num_attention_heads: Number of attention heads for each attention layer in the Transformer encoder. intermediate_size: The size of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. hidden_act: The non-linear activation function (function or string) in the encoder and pooler. If string, "gelu", "relu" and "swish" are supported. hidden_dropout_prob: The dropout probabilitiy for all fully connected layers in the embeddings, encoder, and pooler. attention_probs_dropout_prob: The dropout ratio for the attention probabilities. max_position_embeddings: The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 512 or 1024 or 2048). type_vocab_size: The vocabulary size of the `token_type_ids` passed into `BertModel`. initializer_range: The sttdev of the truncated_normal_initializer for initializing all weight matrices. """ if isinstance(vocab_size_or_config_json_file, str): with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader: json_config = json.loads(reader.read()) for key, value in json_config.items(): self.__dict__[key] = value elif isinstance(vocab_size_or_config_json_file, int): self.vocab_size = vocab_size_or_config_json_file self.hidden_size = hidden_size self.num_hidden_layers = num_hidden_layers self.num_attention_heads = num_attention_heads self.hidden_act = hidden_act self.intermediate_size = intermediate_size self.hidden_dropout_prob = hidden_dropout_prob self.attention_probs_dropout_prob = attention_probs_dropout_prob self.max_position_embeddings = max_position_embeddings self.type_vocab_size = type_vocab_size self.relax_projection = relax_projection self.new_pos_ids = new_pos_ids self.initializer_range = initializer_range self.task_idx = task_idx self.fp32_embedding = fp32_embedding self.ffn_type = ffn_type self.label_smoothing = label_smoothing self.num_qkv = num_qkv self.seg_emb = seg_emb else: raise ValueError("First argument must be either a vocabulary size (int)" "or the path to a pretrained model config file (str)") @classmethod def from_dict(cls, json_object): """Constructs a `BertConfig` from a Python dictionary of parameters.""" config = BertConfig(vocab_size_or_config_json_file=-1) for key, value in json_object.items(): config.__dict__[key] = value return config @classmethod def from_json_file(cls, json_file): """Constructs a `BertConfig` from a json file of parameters.""" with open(json_file, "r", encoding='utf-8') as reader: text = reader.read() return cls.from_dict(json.loads(text)) def __repr__(self): return str(self.to_json_string()) def to_dict(self): """Serializes this instance to a Python dictionary.""" output = copy.deepcopy(self.__dict__) return output def to_json_string(self): """Serializes this instance to a JSON string.""" return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" try: from apex.normalization.fused_layer_norm import FusedLayerNorm as BertLayerNorm except ImportError: print("Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.") class BertLayerNorm(nn.Module): def __init__(self, hidden_size, eps=1e-5): """Construct a layernorm module in the TF style (epsilon inside the square root). """ super(BertLayerNorm, self).__init__() self.weight = nn.Parameter(torch.ones(hidden_size)) self.bias = nn.Parameter(torch.zeros(hidden_size)) self.variance_epsilon = eps def forward(self, x): u = x.mean(-1, keepdim=True) s = (x - u).pow(2).mean(-1, keepdim=True) x = (x - u) / torch.sqrt(s + self.variance_epsilon) return self.weight * x + self.bias class PositionalEmbedding(nn.Module): def __init__(self, demb): super(PositionalEmbedding, self).__init__() self.demb = demb inv_freq = 1 / (10000 ** (torch.arange(0.0, demb, 2.0) / demb)) self.register_buffer('inv_freq', inv_freq) def forward(self, pos_seq, bsz=None): sinusoid_inp = torch.ger(pos_seq, self.inv_freq) pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1) if bsz is not None: return pos_emb[:, None, :].expand(-1, bsz, -1) else: return pos_emb[:, None, :] class BertEmbeddings(nn.Module): """Construct the embeddings from word, position and token_type embeddings. """ def __init__(self, config): super(BertEmbeddings, self).__init__() self.word_embeddings = nn.Embedding( config.vocab_size, config.hidden_size) self.token_type_embeddings = nn.Embedding( config.type_vocab_size, config.hidden_size) if hasattr(config, 'fp32_embedding'): self.fp32_embedding = config.fp32_embedding else: self.fp32_embedding = False if hasattr(config, 'new_pos_ids') and config.new_pos_ids: self.num_pos_emb = 4 else: self.num_pos_emb = 1 self.position_embeddings = nn.Embedding( config.max_position_embeddings, config.hidden_size*self.num_pos_emb) # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load # any TensorFlow checkpoint file self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-5) self.dropout = nn.Dropout(config.hidden_dropout_prob) def forward(self, input_ids, token_type_ids=None, position_ids=None, task_idx=None): seq_length = input_ids.size(1) if position_ids is None: position_ids = torch.arange( seq_length, dtype=torch.long, device=input_ids.device) position_ids = position_ids.unsqueeze(0).expand_as(input_ids) if token_type_ids is None: token_type_ids = torch.zeros_like(input_ids) words_embeddings = self.word_embeddings(input_ids) position_embeddings = self.position_embeddings(position_ids) token_type_embeddings = self.token_type_embeddings(token_type_ids) if self.num_pos_emb > 1: num_batch = position_embeddings.size(0) num_pos = position_embeddings.size(1) position_embeddings = position_embeddings.view( num_batch, num_pos, self.num_pos_emb, -1)[torch.arange(0, num_batch).long(), :, task_idx, :] embeddings = words_embeddings + position_embeddings + token_type_embeddings if self.fp32_embedding: embeddings = embeddings.half() embeddings = self.LayerNorm(embeddings) embeddings = self.dropout(embeddings) return embeddings class BertSelfAttention(nn.Module): def __init__(self, config): super(BertSelfAttention, self).__init__() if config.hidden_size % config.num_attention_heads != 0: raise ValueError( "The hidden size (%d) is not a multiple of the number of attention " "heads (%d)" % (config.hidden_size, config.num_attention_heads)) self.num_attention_heads = config.num_attention_heads self.attention_head_size = int( config.hidden_size / config.num_attention_heads) self.all_head_size = self.num_attention_heads * self.attention_head_size if hasattr(config, 'num_qkv') and (config.num_qkv > 1): self.num_qkv = config.num_qkv else: self.num_qkv = 1 self.query = nn.Linear( config.hidden_size, self.all_head_size*self.num_qkv) self.key = nn.Linear(config.hidden_size, self.all_head_size*self.num_qkv) self.value = nn.Linear( config.hidden_size, self.all_head_size*self.num_qkv) self.dropout = nn.Dropout(config.attention_probs_dropout_prob) self.uni_debug_flag = True if os.getenv( 'UNI_DEBUG_FLAG', '') else False if self.uni_debug_flag: self.register_buffer('debug_attention_probs', torch.zeros((512, 512))) if hasattr(config, 'seg_emb') and config.seg_emb: self.b_q_s = nn.Parameter(torch.zeros( 1, self.num_attention_heads, 1, self.attention_head_size)) self.seg_emb = nn.Embedding( config.type_vocab_size, self.all_head_size) else: self.b_q_s = None self.seg_emb = None def transpose_for_scores(self, x, mask_qkv=None): if self.num_qkv > 1: sz = x.size()[:-1] + (self.num_qkv, self.num_attention_heads, self.all_head_size) # (batch, pos, num_qkv, head, head_hid) x = x.view(*sz) if mask_qkv is None: x = x[:, :, 0, :, :] elif isinstance(mask_qkv, int): x = x[:, :, mask_qkv, :, :] else: # mask_qkv: (batch, pos) if mask_qkv.size(1) > sz[1]: mask_qkv = mask_qkv[:, :sz[1]] # -> x: (batch, pos, head, head_hid) x = x.gather(2, mask_qkv.view(sz[0], sz[1], 1, 1, 1).expand( sz[0], sz[1], 1, sz[3], sz[4])).squeeze(2) else: sz = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) # (batch, pos, head, head_hid) x = x.view(*sz) # (batch, head, pos, head_hid) return x.permute(0, 2, 1, 3) def forward(self, hidden_states, attention_mask, history_states=None, mask_qkv=None, seg_ids=None): if history_states is None: mixed_query_layer = self.query(hidden_states) mixed_key_layer = self.key(hidden_states) mixed_value_layer = self.value(hidden_states) else: x_states = torch.cat((history_states, hidden_states), dim=1) mixed_query_layer = self.query(hidden_states) mixed_key_layer = self.key(x_states) mixed_value_layer = self.value(x_states) query_layer = self.transpose_for_scores(mixed_query_layer, mask_qkv) key_layer = self.transpose_for_scores(mixed_key_layer, mask_qkv) value_layer = self.transpose_for_scores(mixed_value_layer, mask_qkv) # Take the dot product between "query" and "key" to get the raw attention scores. # (batch, head, pos, pos) attention_scores = torch.matmul( query_layer / math.sqrt(self.attention_head_size), key_layer.transpose(-1, -2)) if self.seg_emb is not None: seg_rep = self.seg_emb(seg_ids) # (batch, pos, head, head_hid) seg_rep = seg_rep.view(seg_rep.size(0), seg_rep.size( 1), self.num_attention_heads, self.attention_head_size) qs = torch.einsum('bnih,bjnh->bnij', query_layer+self.b_q_s, seg_rep) attention_scores = attention_scores + qs # attention_scores = attention_scores / math.sqrt(self.attention_head_size) # Apply the attention mask is (precomputed for all layers in BertModel forward() function) attention_scores = attention_scores + attention_mask # Normalize the attention scores to probabilities. attention_probs = nn.Softmax(dim=-1)(attention_scores) if self.uni_debug_flag: _pos = attention_probs.size(-1) self.debug_attention_probs[:_pos, :_pos].copy_( attention_probs[0].mean(0).view(_pos, _pos)) # This is actually dropping out entire tokens to attend to, which might # seem a bit unusual, but is taken from the original Transformer paper. attention_probs = self.dropout(attention_probs) context_layer = torch.matmul(attention_probs, value_layer) context_layer = context_layer.permute(0, 2, 1, 3).contiguous() new_context_layer_shape = context_layer.size()[ :-2] + (self.all_head_size,) context_layer = context_layer.view(*new_context_layer_shape) return context_layer class BertSelfOutput(nn.Module): def __init__(self, config): super(BertSelfOutput, self).__init__() self.dense = nn.Linear(config.hidden_size, config.hidden_size) self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-5) self.dropout = nn.Dropout(config.hidden_dropout_prob) def forward(self, hidden_states, input_tensor): hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) hidden_states = self.LayerNorm(hidden_states + input_tensor) return hidden_states class BertAttention(nn.Module): def __init__(self, config): super(BertAttention, self).__init__() self.self = BertSelfAttention(config) self.output = BertSelfOutput(config) def forward(self, input_tensor, attention_mask, history_states=None, mask_qkv=None, seg_ids=None): self_output = self.self( input_tensor, attention_mask, history_states=history_states, mask_qkv=mask_qkv, seg_ids=seg_ids) attention_output = self.output(self_output, input_tensor) return attention_output class BertIntermediate(nn.Module): def __init__(self, config): super(BertIntermediate, self).__init__() self.dense = nn.Linear(config.hidden_size, config.intermediate_size) self.intermediate_act_fn = ACT2FN[config.hidden_act] \ if isinstance(config.hidden_act, str) else config.hidden_act def forward(self, hidden_states): hidden_states = self.dense(hidden_states) hidden_states = self.intermediate_act_fn(hidden_states) return hidden_states class BertOutput(nn.Module): def __init__(self, config): super(BertOutput, self).__init__() self.dense = nn.Linear(config.intermediate_size, config.hidden_size) self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-5) self.dropout = nn.Dropout(config.hidden_dropout_prob) def forward(self, hidden_states, input_tensor): hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) hidden_states = self.LayerNorm(hidden_states + input_tensor) return hidden_states class TransformerFFN(nn.Module): def __init__(self, config): super(TransformerFFN, self).__init__() self.ffn_type = config.ffn_type assert self.ffn_type in (1, 2) if self.ffn_type in (1, 2): self.wx0 = nn.Linear(config.hidden_size, config.hidden_size) if self.ffn_type in (2,): self.wx1 = nn.Linear(config.hidden_size, config.hidden_size) if self.ffn_type in (1, 2): self.output = nn.Linear(config.hidden_size, config.hidden_size) self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-5) self.dropout = nn.Dropout(config.hidden_dropout_prob) def forward(self, x): if self.ffn_type in (1, 2): x0 = self.wx0(x) if self.ffn_type == 1: x1 = x elif self.ffn_type == 2: x1 = self.wx1(x) out = self.output(x0 * x1) out = self.dropout(out) out = self.LayerNorm(out + x) return out class BertLayer(nn.Module): def __init__(self, config): super(BertLayer, self).__init__() self.attention = BertAttention(config) self.ffn_type = config.ffn_type if self.ffn_type: self.ffn = TransformerFFN(config) else: self.intermediate = BertIntermediate(config) self.output = BertOutput(config) def forward(self, hidden_states, attention_mask, history_states=None, mask_qkv=None, seg_ids=None): attention_output = self.attention( hidden_states, attention_mask, history_states=history_states, mask_qkv=mask_qkv, seg_ids=seg_ids) if self.ffn_type: layer_output = self.ffn(attention_output) else: intermediate_output = self.intermediate(attention_output) layer_output = self.output(intermediate_output, attention_output) return layer_output class BertEncoder(nn.Module): def __init__(self, config): super(BertEncoder, self).__init__() layer = BertLayer(config) self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)]) def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True, prev_embedding=None, prev_encoded_layers=None, mask_qkv=None, seg_ids=None): # history embedding and encoded layer must be simultanously given assert (prev_embedding is None) == (prev_encoded_layers is None) all_encoder_layers = [] if (prev_embedding is not None) and (prev_encoded_layers is not None): history_states = prev_embedding for i, layer_module in enumerate(self.layer): hidden_states = layer_module( hidden_states, attention_mask, history_states=history_states, mask_qkv=mask_qkv, seg_ids=seg_ids) if output_all_encoded_layers: all_encoder_layers.append(hidden_states) if prev_encoded_layers is not None: history_states = prev_encoded_layers[i] else: for layer_module in self.layer: hidden_states = layer_module( hidden_states, attention_mask, mask_qkv=mask_qkv, seg_ids=seg_ids) if output_all_encoded_layers: all_encoder_layers.append(hidden_states) if not output_all_encoded_layers: all_encoder_layers.append(hidden_states) return all_encoder_layers class BertPooler(nn.Module): def __init__(self, config): super(BertPooler, self).__init__() self.dense = nn.Linear(config.hidden_size, config.hidden_size) self.activation = nn.Tanh() def forward(self, hidden_states): # We "pool" the model by simply taking the hidden state corresponding # to the first token. first_token_tensor = hidden_states[:, 0] pooled_output = self.dense(first_token_tensor) pooled_output = self.activation(pooled_output) return pooled_output class BertPredictionHeadTransform(nn.Module): def __init__(self, config): super(BertPredictionHeadTransform, self).__init__() self.transform_act_fn = ACT2FN[config.hidden_act] \ if isinstance(config.hidden_act, str) else config.hidden_act hid_size = config.hidden_size if hasattr(config, 'relax_projection') and (config.relax_projection > 1): hid_size *= config.relax_projection self.dense = nn.Linear(config.hidden_size, hid_size) self.LayerNorm = BertLayerNorm(hid_size, eps=1e-5) def forward(self, hidden_states): hidden_states = self.dense(hidden_states) hidden_states = self.transform_act_fn(hidden_states) hidden_states = self.LayerNorm(hidden_states) return hidden_states class BertLMPredictionHead(nn.Module): def __init__(self, config, bert_model_embedding_weights): super(BertLMPredictionHead, self).__init__() self.transform = BertPredictionHeadTransform(config) # The output weights are the same as the input embeddings, but there is # an output-only bias for each token. self.decoder = nn.Linear(bert_model_embedding_weights.size(1), bert_model_embedding_weights.size(0), bias=False) self.decoder.weight = bert_model_embedding_weights self.bias = nn.Parameter(torch.zeros( bert_model_embedding_weights.size(0))) if hasattr(config, 'relax_projection') and (config.relax_projection > 1): self.relax_projection = config.relax_projection else: self.relax_projection = 0 self.fp32_embedding = config.fp32_embedding def convert_to_type(tensor): if self.fp32_embedding: return tensor.half() else: return tensor self.type_converter = convert_to_type self.converted = False def forward(self, hidden_states, task_idx=None): if not self.converted: self.converted = True if self.fp32_embedding: self.transform.half() hidden_states = self.transform(self.type_converter(hidden_states)) if self.relax_projection > 1: num_batch = hidden_states.size(0) num_pos = hidden_states.size(1) # (batch, num_pos, relax_projection*hid) -> (batch, num_pos, relax_projection, hid) -> (batch, num_pos, hid) hidden_states = hidden_states.view( num_batch, num_pos, self.relax_projection, -1)[torch.arange(0, num_batch).long(), :, task_idx, :] if self.fp32_embedding: hidden_states = F.linear(self.type_converter(hidden_states), self.type_converter( self.decoder.weight), self.type_converter(self.bias)) else: hidden_states = self.decoder(hidden_states) + self.bias return hidden_states class BertOnlyMLMHead(nn.Module): def __init__(self, config, bert_model_embedding_weights): super(BertOnlyMLMHead, self).__init__() self.predictions = BertLMPredictionHead( config, bert_model_embedding_weights) def forward(self, sequence_output): prediction_scores = self.predictions(sequence_output) return prediction_scores class BertOnlyNSPHead(nn.Module): def __init__(self, config): super(BertOnlyNSPHead, self).__init__() self.seq_relationship = nn.Linear(config.hidden_size, 2) def forward(self, pooled_output): seq_relationship_score = self.seq_relationship(pooled_output) return seq_relationship_score class BertPreTrainingHeads(nn.Module): def __init__(self, config, bert_model_embedding_weights, num_labels=2): super(BertPreTrainingHeads, self).__init__() self.predictions = BertLMPredictionHead( config, bert_model_embedding_weights) self.seq_relationship = nn.Linear(config.hidden_size, num_labels) def forward(self, sequence_output, pooled_output, task_idx=None): prediction_scores = self.predictions(sequence_output, task_idx) if pooled_output is None: seq_relationship_score = None else: seq_relationship_score = self.seq_relationship(pooled_output) return prediction_scores, seq_relationship_score class PreTrainedBertModel(nn.Module): """ An abstract class to handle weights initialization and a simple interface for dowloading and loading pretrained models. """ def __init__(self, config, *inputs, **kwargs): super(PreTrainedBertModel, self).__init__() if not isinstance(config, BertConfig): raise ValueError( "Parameter config in `{}(config)` should be an instance of class `BertConfig`. " "To create a model from a Google pretrained model use " "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format( self.__class__.__name__, self.__class__.__name__ )) self.config = config def init_bert_weights(self, module): """ Initialize the weights. """ if isinstance(module, (nn.Linear, nn.Embedding)): module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) elif isinstance(module, BertLayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) if isinstance(module, nn.Linear) and module.bias is not None: module.bias.data.zero_() @classmethod def from_pretrained(cls, pretrained_model_name, state_dict=None, cache_dir=None, *inputs, **kwargs): """ Instantiate a PreTrainedBertModel from a pre-trained model file or a pytorch state dict. Download and cache the pre-trained model file if needed. Params: pretrained_model_name: either: - a str with the name of a pre-trained model to load selected in the list of: . `bert-base-uncased` . `bert-large-uncased` . `bert-base-cased` . `bert-base-multilingual` . `bert-base-chinese` - a path or url to a pretrained model archive containing: . `bert_config.json` a configuration file for the model . `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance cache_dir: an optional path to a folder in which the pre-trained models will be cached. state_dict: an optional state dictionnary (collections.OrderedDict object) to use instead of Google pre-trained models *inputs, **kwargs: additional input for the specific Bert class (ex: num_labels for BertForSequenceClassification) """ if pretrained_model_name in PRETRAINED_MODEL_ARCHIVE_MAP: archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name] else: archive_file = pretrained_model_name # redirect to the cache, if necessary try: resolved_archive_file = cached_path( archive_file, cache_dir=cache_dir) except FileNotFoundError: logger.error( "Model name '{}' was not found in model name list ({}). " "We assumed '{}' was a path or url but couldn't find any file " "associated to this path or url.".format( pretrained_model_name, ', '.join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()), archive_file)) return None if resolved_archive_file == archive_file: logger.info("loading archive file {}".format(archive_file)) else: logger.info("loading archive file {} from cache at {}".format( archive_file, resolved_archive_file)) tempdir = None if os.path.isdir(resolved_archive_file): serialization_dir = resolved_archive_file else: # Extract archive to temp dir tempdir = tempfile.mkdtemp() logger.info("extracting archive file {} to temp dir {}".format( resolved_archive_file, tempdir)) with tarfile.open(resolved_archive_file, 'r:gz') as archive: archive.extractall(tempdir) serialization_dir = tempdir # Load config if ('config_path' in kwargs) and kwargs['config_path']: config_file = kwargs['config_path'] else: config_file = os.path.join(serialization_dir, CONFIG_NAME) config = BertConfig.from_json_file(config_file) # define new type_vocab_size (there might be different numbers of segment ids) if 'type_vocab_size' in kwargs: config.type_vocab_size = kwargs['type_vocab_size'] # define new relax_projection if ('relax_projection' in kwargs) and kwargs['relax_projection']: config.relax_projection = kwargs['relax_projection'] # new position embedding if ('new_pos_ids' in kwargs) and kwargs['new_pos_ids']: config.new_pos_ids = kwargs['new_pos_ids'] # define new relax_projection if ('task_idx' in kwargs) and kwargs['task_idx']: config.task_idx = kwargs['task_idx'] # define new max position embedding for length expansion if ('max_position_embeddings' in kwargs) and kwargs['max_position_embeddings']: config.max_position_embeddings = kwargs['max_position_embeddings'] # use fp32 for embeddings if ('fp32_embedding' in kwargs) and kwargs['fp32_embedding']: config.fp32_embedding = kwargs['fp32_embedding'] # type of FFN in transformer blocks if ('ffn_type' in kwargs) and kwargs['ffn_type']: config.ffn_type = kwargs['ffn_type'] # label smoothing if ('label_smoothing' in kwargs) and kwargs['label_smoothing']: config.label_smoothing = kwargs['label_smoothing'] # dropout if ('hidden_dropout_prob' in kwargs) and kwargs['hidden_dropout_prob']: config.hidden_dropout_prob = kwargs['hidden_dropout_prob'] if ('attention_probs_dropout_prob' in kwargs) and kwargs['attention_probs_dropout_prob']: config.attention_probs_dropout_prob = kwargs['attention_probs_dropout_prob'] # different QKV if ('num_qkv' in kwargs) and kwargs['num_qkv']: config.num_qkv = kwargs['num_qkv'] # segment embedding for self-attention if ('seg_emb' in kwargs) and kwargs['seg_emb']: config.seg_emb = kwargs['seg_emb'] # initialize word embeddings _word_emb_map = None if ('word_emb_map' in kwargs) and kwargs['word_emb_map']: _word_emb_map = kwargs['word_emb_map'] logger.info("Model config {}".format(config)) # clean the arguments in kwargs for arg_clean in ('config_path', 'type_vocab_size', 'relax_projection', 'new_pos_ids', 'task_idx', 'max_position_embeddings', 'fp32_embedding', 'ffn_type', 'label_smoothing', 'hidden_dropout_prob', 'attention_probs_dropout_prob', 'num_qkv', 'seg_emb', 'word_emb_map'): if arg_clean in kwargs: del kwargs[arg_clean] # Instantiate model. model = cls(config, *inputs, **kwargs) if state_dict is None: weights_path = os.path.join(serialization_dir, WEIGHTS_NAME) state_dict = torch.load(weights_path) old_keys = [] new_keys = [] for key in state_dict.keys(): new_key = None if 'gamma' in key: new_key = key.replace('gamma', 'weight') if 'beta' in key: new_key = key.replace('beta', 'bias') if new_key: old_keys.append(key) new_keys.append(new_key) for old_key, new_key in zip(old_keys, new_keys): state_dict[new_key] = state_dict.pop(old_key) # initialize new segment embeddings _k = 'bert.embeddings.token_type_embeddings.weight' if (_k in state_dict) and (config.type_vocab_size != state_dict[_k].shape[0]): logger.info("config.type_vocab_size != state_dict[bert.embeddings.token_type_embeddings.weight] ({0} != {1})".format( config.type_vocab_size, state_dict[_k].shape[0])) if config.type_vocab_size > state_dict[_k].shape[0]: # state_dict[_k].data = state_dict[_k].data.resize_(config.type_vocab_size, state_dict[_k].shape[1]) state_dict[_k].resize_( config.type_vocab_size, state_dict[_k].shape[1]) # L2R if config.type_vocab_size >= 3: state_dict[_k].data[2, :].copy_(state_dict[_k].data[0, :]) # R2L if config.type_vocab_size >= 4: state_dict[_k].data[3, :].copy_(state_dict[_k].data[0, :]) # S2S if config.type_vocab_size >= 6: state_dict[_k].data[4, :].copy_(state_dict[_k].data[0, :]) state_dict[_k].data[5, :].copy_(state_dict[_k].data[1, :]) if config.type_vocab_size >= 7: state_dict[_k].data[6, :].copy_(state_dict[_k].data[1, :]) elif config.type_vocab_size < state_dict[_k].shape[0]: state_dict[_k].data = state_dict[_k].data[:config.type_vocab_size, :] _k = 'bert.embeddings.position_embeddings.weight' n_config_pos_emb = 4 if config.new_pos_ids else 1 if (_k in state_dict) and (n_config_pos_emb*config.hidden_size != state_dict[_k].shape[1]): logger.info("n_config_pos_emb*config.hidden_size != state_dict[bert.embeddings.position_embeddings.weight] ({0}*{1} != {2})".format( n_config_pos_emb, config.hidden_size, state_dict[_k].shape[1])) assert state_dict[_k].shape[1] % config.hidden_size == 0 n_state_pos_emb = int(state_dict[_k].shape[1]/config.hidden_size) assert (n_state_pos_emb == 1) != (n_config_pos_emb == 1), "!!!!n_state_pos_emb == 1 xor n_config_pos_emb == 1!!!!" if n_state_pos_emb == 1: state_dict[_k].data = state_dict[_k].data.unsqueeze(1).repeat( 1, n_config_pos_emb, 1).reshape((config.max_position_embeddings, n_config_pos_emb*config.hidden_size)) elif n_config_pos_emb == 1: if hasattr(config, 'task_idx') and (config.task_idx is not None) and (0 <= config.task_idx <= 3): _task_idx = config.task_idx else: _task_idx = 0 state_dict[_k].data = state_dict[_k].data.view( config.max_position_embeddings, n_state_pos_emb, config.hidden_size).select(1, _task_idx) # initialize new position embeddings _k = 'bert.embeddings.position_embeddings.weight' if _k in state_dict and config.max_position_embeddings != state_dict[_k].shape[0]: logger.info("config.max_position_embeddings != state_dict[bert.embeddings.position_embeddings.weight] ({0} - {1})".format( config.max_position_embeddings, state_dict[_k].shape[0])) if config.max_position_embeddings > state_dict[_k].shape[0]: old_size = state_dict[_k].shape[0] # state_dict[_k].data = state_dict[_k].data.resize_(config.max_position_embeddings, state_dict[_k].shape[1]) state_dict[_k].resize_( config.max_position_embeddings, state_dict[_k].shape[1]) start = old_size while start < config.max_position_embeddings: chunk_size = min( old_size, config.max_position_embeddings - start) state_dict[_k].data[start:start+chunk_size, :].copy_(state_dict[_k].data[:chunk_size, :]) start += chunk_size elif config.max_position_embeddings < state_dict[_k].shape[0]: state_dict[_k].data = state_dict[_k].data[:config.max_position_embeddings, :] # initialize relax projection _k = 'cls.predictions.transform.dense.weight' n_config_relax = 1 if (config.relax_projection < 1) else config.relax_projection if (_k in state_dict) and (n_config_relax*config.hidden_size != state_dict[_k].shape[0]): logger.info("n_config_relax*config.hidden_size != state_dict[cls.predictions.transform.dense.weight] ({0}*{1} != {2})".format( n_config_relax, config.hidden_size, state_dict[_k].shape[0])) assert state_dict[_k].shape[0] % config.hidden_size == 0 n_state_relax = int(state_dict[_k].shape[0]/config.hidden_size) assert (n_state_relax == 1) != (n_config_relax == 1), "!!!!n_state_relax == 1 xor n_config_relax == 1!!!!" if n_state_relax == 1: _k = 'cls.predictions.transform.dense.weight' state_dict[_k].data = state_dict[_k].data.unsqueeze(0).repeat( n_config_relax, 1, 1).reshape((n_config_relax*config.hidden_size, config.hidden_size)) for _k in ('cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias'): state_dict[_k].data = state_dict[_k].data.unsqueeze( 0).repeat(n_config_relax, 1).view(-1) elif n_config_relax == 1: if hasattr(config, 'task_idx') and (config.task_idx is not None) and (0 <= config.task_idx <= 3): _task_idx = config.task_idx else: _task_idx = 0 _k = 'cls.predictions.transform.dense.weight' state_dict[_k].data = state_dict[_k].data.view( n_state_relax, config.hidden_size, config.hidden_size).select(0, _task_idx) for _k in ('cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias'): state_dict[_k].data = state_dict[_k].data.view( n_state_relax, config.hidden_size).select(0, _task_idx) # initialize QKV _all_head_size = config.num_attention_heads * \ int(config.hidden_size / config.num_attention_heads) n_config_num_qkv = 1 if (config.num_qkv < 1) else config.num_qkv for qkv_name in ('query', 'key', 'value'): _k = 'bert.encoder.layer.0.attention.self.{0}.weight'.format( qkv_name) if (_k in state_dict) and (n_config_num_qkv*_all_head_size != state_dict[_k].shape[0]): logger.info("n_config_num_qkv*_all_head_size != state_dict[_k] ({0}*{1} != {2})".format( n_config_num_qkv, _all_head_size, state_dict[_k].shape[0])) for layer_idx in range(config.num_hidden_layers): _k = 'bert.encoder.layer.{0}.attention.self.{1}.weight'.format( layer_idx, qkv_name) assert state_dict[_k].shape[0] % _all_head_size == 0 n_state_qkv = int(state_dict[_k].shape[0]/_all_head_size) assert (n_state_qkv == 1) != (n_config_num_qkv == 1), "!!!!n_state_qkv == 1 xor n_config_num_qkv == 1!!!!" if n_state_qkv == 1: _k = 'bert.encoder.layer.{0}.attention.self.{1}.weight'.format( layer_idx, qkv_name) state_dict[_k].data = state_dict[_k].data.unsqueeze(0).repeat( n_config_num_qkv, 1, 1).reshape((n_config_num_qkv*_all_head_size, _all_head_size)) _k = 'bert.encoder.layer.{0}.attention.self.{1}.bias'.format( layer_idx, qkv_name) state_dict[_k].data = state_dict[_k].data.unsqueeze( 0).repeat(n_config_num_qkv, 1).view(-1) elif n_config_num_qkv == 1: if hasattr(config, 'task_idx') and (config.task_idx is not None) and (0 <= config.task_idx <= 3): _task_idx = config.task_idx else: _task_idx = 0 assert _task_idx != 3, "[INVALID] _task_idx=3: n_config_num_qkv=1 (should be 2)" if _task_idx == 0: _qkv_idx = 0 else: _qkv_idx = 1 _k = 'bert.encoder.layer.{0}.attention.self.{1}.weight'.format( layer_idx, qkv_name) state_dict[_k].data = state_dict[_k].data.view( n_state_qkv, _all_head_size, _all_head_size).select(0, _qkv_idx) _k = 'bert.encoder.layer.{0}.attention.self.{1}.bias'.format( layer_idx, qkv_name) state_dict[_k].data = state_dict[_k].data.view( n_state_qkv, _all_head_size).select(0, _qkv_idx) if _word_emb_map: _k = 'bert.embeddings.word_embeddings.weight' for _tgt, _src in _word_emb_map: state_dict[_k].data[_tgt, :].copy_( state_dict[_k].data[_src, :]) missing_keys = [] unexpected_keys = [] error_msgs = [] # copy state_dict so _load_from_state_dict can modify it metadata = getattr(state_dict, '_metadata', None) state_dict = state_dict.copy() if metadata is not None: state_dict._metadata = metadata def load(module, prefix=''): local_metadata = {} if metadata is None else metadata.get( prefix[:-1], {}) module._load_from_state_dict( state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) for name, child in module._modules.items(): if child is not None: load(child, prefix + name + '.') load(model, prefix='' if hasattr(model, 'bert') else 'bert.') model.missing_keys = missing_keys if len(missing_keys) > 0: logger.info("Weights of {} not initialized from pretrained model: {}".format( model.__class__.__name__, missing_keys)) if len(unexpected_keys) > 0: logger.info("Weights from pretrained model not used in {}: {}".format( model.__class__.__name__, unexpected_keys)) if len(error_msgs) > 0: logger.info('\n'.join(error_msgs)) if tempdir: # Clean up temp dir shutil.rmtree(tempdir) return model class BertModel(PreTrainedBertModel): """BERT model ("Bidirectional Embedding Representations from a Transformer"). Params: config: a BertConfig class instance with the configuration to build a new model Inputs: `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts `extract_features.py`, `run_classifier.py` and `run_squad.py`) `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to a `sentence B` token (see BERT paper for more details). `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max input sequence length in the current batch. It's the mask that we typically use for attention when a batch has varying length sentences. `output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output as described below. Default: `True`. Outputs: Tuple of (encoded_layers, pooled_output) `encoded_layers`: controled by `output_all_encoded_layers` argument: - `output_all_encoded_layers=True`: outputs a list of the full sequences of encoded-hidden-states at the end of each attention block (i.e. 12 full sequences for BERT-base, 24 for BERT-large), each encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size], - `output_all_encoded_layers=False`: outputs only the full sequence of hidden-states corresponding to the last attention block of shape [batch_size, sequence_length, hidden_size], `pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a classifier pretrained on top of the hidden state associated to the first character of the input (`CLF`) to train on the Next-Sentence task (see BERT's paper). ``` """ def __init__(self, config): super(BertModel, self).__init__(config) self.embeddings = BertEmbeddings(config) self.encoder = BertEncoder(config) self.pooler = BertPooler(config) self.apply(self.init_bert_weights) def rescale_some_parameters(self): for layer_id, layer in enumerate(self.encoder.layer): layer.attention.output.dense.weight.data.div_( math.sqrt(2.0*(layer_id + 1))) layer.output.dense.weight.data.div_(math.sqrt(2.0*(layer_id + 1))) def get_extended_attention_mask(self, input_ids, token_type_ids, attention_mask): if attention_mask is None: attention_mask = torch.ones_like(input_ids) if token_type_ids is None: token_type_ids = torch.zeros_like(input_ids) # We create a 3D attention mask from a 2D tensor mask. # Sizes are [batch_size, 1, 1, to_seq_length] # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] # this attention mask is more simple than the triangular masking of causal attention # used in OpenAI GPT, we just need to prepare the broadcast dimension here. if attention_mask.dim() == 2: extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) elif attention_mask.dim() == 3: extended_attention_mask = attention_mask.unsqueeze(1) else: raise NotImplementedError # Since attention_mask is 1.0 for positions we want to attend and 0.0 for # masked positions, this operation will create a tensor which is 0.0 for # positions we want to attend and -10000.0 for masked positions. # Since we are adding it to the raw scores before the softmax, this is # effectively the same as removing these entirely. extended_attention_mask = extended_attention_mask.to( dtype=next(self.parameters()).dtype) # fp16 compatibility extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 return extended_attention_mask def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True, mask_qkv=None, task_idx=None): extended_attention_mask = self.get_extended_attention_mask( input_ids, token_type_ids, attention_mask) embedding_output = self.embeddings( input_ids, token_type_ids, task_idx=task_idx) encoded_layers = self.encoder(embedding_output, extended_attention_mask, output_all_encoded_layers=output_all_encoded_layers, mask_qkv=mask_qkv, seg_ids=token_type_ids) sequence_output = encoded_layers[-1] pooled_output = self.pooler(sequence_output) if not output_all_encoded_layers: encoded_layers = encoded_layers[-1] return encoded_layers, pooled_output class BertModelIncr(BertModel): def __init__(self, config): super(BertModelIncr, self).__init__(config) def forward(self, input_ids, token_type_ids, position_ids, attention_mask, output_all_encoded_layers=True, prev_embedding=None, prev_encoded_layers=None, mask_qkv=None, task_idx=None): extended_attention_mask = self.get_extended_attention_mask( input_ids, token_type_ids, attention_mask) embedding_output = self.embeddings( input_ids, token_type_ids, position_ids, task_idx=task_idx) encoded_layers = self.encoder(embedding_output, extended_attention_mask, output_all_encoded_layers=output_all_encoded_layers, prev_embedding=prev_embedding, prev_encoded_layers=prev_encoded_layers, mask_qkv=mask_qkv, seg_ids=token_type_ids) sequence_output = encoded_layers[-1] pooled_output = self.pooler(sequence_output) if not output_all_encoded_layers: encoded_layers = encoded_layers[-1] return embedding_output, encoded_layers, pooled_output class BertForPreTraining(PreTrainedBertModel): """BERT model with pre-training heads. This module comprises the BERT model followed by the two pre-training heads: - the masked language modeling head, and - the next sentence classification head. Params: config: a BertConfig class instance with the configuration to build a new model. Inputs: `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts `extract_features.py`, `run_classifier.py` and `run_squad.py`) `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to a `sentence B` token (see BERT paper for more details). `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max input sequence length in the current batch. It's the mask that we typically use for attention when a batch has varying length sentences. `masked_lm_labels`: masked language modeling labels: torch.LongTensor of shape [batch_size, sequence_length] with indices selected in [-1, 0, ..., vocab_size]. All labels set to -1 are ignored (masked), the loss is only computed for the labels set in [0, ..., vocab_size] `next_sentence_label`: next sentence classification loss: torch.LongTensor of shape [batch_size] with indices selected in [0, 1]. 0 => next sentence is the continuation, 1 => next sentence is a random sentence. Outputs: if `masked_lm_labels` and `next_sentence_label` are not `None`: Outputs the total_loss which is the sum of the masked language modeling loss and the next sentence classification loss. if `masked_lm_labels` or `next_sentence_label` is `None`: Outputs a tuple comprising - the masked language modeling logits of shape [batch_size, sequence_length, vocab_size], and - the next sentence classification logits of shape [batch_size, 2]. Example usage: ```python # Already been converted into WordPiece token ids input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) model = BertForPreTraining(config) masked_lm_logits_scores, seq_relationship_logits = model(input_ids, token_type_ids, input_mask) ``` """ def __init__(self, config): super(BertForPreTraining, self).__init__(config) self.bert = BertModel(config) self.cls = BertPreTrainingHeads( config, self.bert.embeddings.word_embeddings.weight) self.apply(self.init_bert_weights) def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, next_sentence_label=None, mask_qkv=None, task_idx=None): sequence_output, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False, mask_qkv=mask_qkv, task_idx=task_idx) prediction_scores, seq_relationship_score = self.cls( sequence_output, pooled_output) if masked_lm_labels is not None and next_sentence_label is not None: loss_fct = CrossEntropyLoss(ignore_index=-1) masked_lm_loss = loss_fct( prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1)) next_sentence_loss = loss_fct( seq_relationship_score.view(-1, 2), next_sentence_label.view(-1)) total_loss = masked_lm_loss + next_sentence_loss return total_loss else: return prediction_scores, seq_relationship_score class BertPreTrainingPairTransform(nn.Module): def __init__(self, config): super(BertPreTrainingPairTransform, self).__init__() self.dense = nn.Linear(config.hidden_size*2, config.hidden_size) self.transform_act_fn = ACT2FN[config.hidden_act] \ if isinstance(config.hidden_act, str) else config.hidden_act # self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-5) def forward(self, pair_x, pair_y): hidden_states = torch.cat([pair_x, pair_y], dim=-1) hidden_states = self.dense(hidden_states) hidden_states = self.transform_act_fn(hidden_states) # hidden_states = self.LayerNorm(hidden_states) return hidden_states class BertPreTrainingPairRel(nn.Module): def __init__(self, config, num_rel=0): super(BertPreTrainingPairRel, self).__init__() self.R_xy = BertPreTrainingPairTransform(config) self.rel_emb = nn.Embedding(num_rel, config.hidden_size) def forward(self, pair_x, pair_y, pair_r, pair_pos_neg_mask): # (batch, num_pair, hidden) xy = self.R_xy(pair_x, pair_y) r = self.rel_emb(pair_r) _batch, _num_pair, _hidden = xy.size() pair_score = (xy * r).sum(-1) # torch.bmm(xy.view(-1, 1, _hidden),r.view(-1, _hidden, 1)).view(_batch, _num_pair) # .mul_(-1.0): objective to loss return F.logsigmoid(pair_score * pair_pos_neg_mask.type_as(pair_score)).mul_(-1.0) class BertForPreTrainingLossMask(PreTrainedBertModel): """refer to BertForPreTraining""" def __init__(self, config, num_labels=2, num_rel=0, num_sentlvl_labels=0, no_nsp=False): super(BertForPreTrainingLossMask, self).__init__(config) self.bert = BertModel(config) self.cls = BertPreTrainingHeads( config, self.bert.embeddings.word_embeddings.weight, num_labels=num_labels) self.num_sentlvl_labels = num_sentlvl_labels self.cls2 = None if self.num_sentlvl_labels > 0: self.secondary_pred_proj = nn.Embedding( num_sentlvl_labels, config.hidden_size) self.cls2 = BertPreTrainingHeads( config, self.secondary_pred_proj.weight, num_labels=num_sentlvl_labels) self.crit_mask_lm = nn.CrossEntropyLoss(reduction='none') if no_nsp: self.crit_next_sent = None else: self.crit_next_sent = nn.CrossEntropyLoss(ignore_index=-1) self.num_labels = num_labels self.num_rel = num_rel if self.num_rel > 0: self.crit_pair_rel = BertPreTrainingPairRel( config, num_rel=num_rel) if hasattr(config, 'label_smoothing') and config.label_smoothing: self.crit_mask_lm_smoothed = LabelSmoothingLoss( config.label_smoothing, config.vocab_size, ignore_index=0, reduction='none') else: self.crit_mask_lm_smoothed = None self.apply(self.init_bert_weights) self.bert.rescale_some_parameters() def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, next_sentence_label=None, masked_pos=None, masked_weights=None, task_idx=None, pair_x=None, pair_x_mask=None, pair_y=None, pair_y_mask=None, pair_r=None, pair_pos_neg_mask=None, pair_loss_mask=None, masked_pos_2=None, masked_weights_2=None, masked_labels_2=None, num_tokens_a=None, num_tokens_b=None, mask_qkv=None): if token_type_ids is None and attention_mask is None: task_0 = (task_idx == 0) task_1 = (task_idx == 1) task_2 = (task_idx == 2) task_3 = (task_idx == 3) sequence_length = input_ids.shape[-1] index_matrix = torch.arange(sequence_length).view( 1, sequence_length).to(input_ids.device) num_tokens = num_tokens_a + num_tokens_b base_mask = (index_matrix < num_tokens.view(-1, 1) ).type_as(input_ids) segment_a_mask = ( index_matrix < num_tokens_a.view(-1, 1)).type_as(input_ids) token_type_ids = ( task_idx + 1 + task_3.type_as(task_idx)).view(-1, 1) * base_mask token_type_ids = token_type_ids - segment_a_mask * \ (task_0 | task_3).type_as(segment_a_mask).view(-1, 1) index_matrix = index_matrix.view(1, 1, sequence_length) index_matrix_t = index_matrix.view(1, sequence_length, 1) tril = index_matrix <= index_matrix_t attention_mask_task_0 = ( index_matrix < num_tokens.view(-1, 1, 1)) & (index_matrix_t < num_tokens.view(-1, 1, 1)) attention_mask_task_1 = tril & attention_mask_task_0 attention_mask_task_2 = torch.transpose( tril, dim0=-2, dim1=-1) & attention_mask_task_0 attention_mask_task_3 = ( (index_matrix < num_tokens_a.view(-1, 1, 1)) | tril) & attention_mask_task_0 attention_mask = (attention_mask_task_0 & task_0.view(-1, 1, 1)) | \ (attention_mask_task_1 & task_1.view(-1, 1, 1)) | \ (attention_mask_task_2 & task_2.view(-1, 1, 1)) | \ (attention_mask_task_3 & task_3.view(-1, 1, 1)) attention_mask = attention_mask.type_as(input_ids) sequence_output, pooled_output = self.bert( input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False, mask_qkv=mask_qkv, task_idx=task_idx) def gather_seq_out_by_pos(seq, pos): return torch.gather(seq, 1, pos.unsqueeze(2).expand(-1, -1, seq.size(-1))) def gather_seq_out_by_pos_average(seq, pos, mask): # pos/mask: (batch, num_pair, max_token_num) batch_size, max_token_num = pos.size(0), pos.size(-1) # (batch, num_pair, max_token_num, seq.size(-1)) pos_vec = torch.gather(seq, 1, pos.view(batch_size, -1).unsqueeze( 2).expand(-1, -1, seq.size(-1))).view(batch_size, -1, max_token_num, seq.size(-1)) # (batch, num_pair, seq.size(-1)) mask = mask.type_as(pos_vec) pos_vec_masked_sum = ( pos_vec * mask.unsqueeze(3).expand_as(pos_vec)).sum(2) return pos_vec_masked_sum / mask.sum(2, keepdim=True).expand_as(pos_vec_masked_sum) def loss_mask_and_normalize(loss, mask): mask = mask.type_as(loss) loss = loss * mask denominator = torch.sum(mask) + 1e-5 return (loss / denominator).sum() if masked_lm_labels is None: if masked_pos is None: prediction_scores, seq_relationship_score = self.cls( sequence_output, pooled_output, task_idx=task_idx) else: sequence_output_masked = gather_seq_out_by_pos( sequence_output, masked_pos) prediction_scores, seq_relationship_score = self.cls( sequence_output_masked, pooled_output, task_idx=task_idx) return prediction_scores, seq_relationship_score # masked lm sequence_output_masked = gather_seq_out_by_pos( sequence_output, masked_pos) prediction_scores_masked, seq_relationship_score = self.cls( sequence_output_masked, pooled_output, task_idx=task_idx) if self.crit_mask_lm_smoothed: masked_lm_loss = self.crit_mask_lm_smoothed( F.log_softmax(prediction_scores_masked.float(), dim=-1), masked_lm_labels) else: masked_lm_loss = self.crit_mask_lm( prediction_scores_masked.transpose(1, 2).float(), masked_lm_labels) masked_lm_loss = loss_mask_and_normalize( masked_lm_loss.float(), masked_weights) # next sentence if self.crit_next_sent is None or next_sentence_label is None: next_sentence_loss = 0.0 else: next_sentence_loss = self.crit_next_sent( seq_relationship_score.view(-1, self.num_labels).float(), next_sentence_label.view(-1)) if self.cls2 is not None and masked_pos_2 is not None: sequence_output_masked_2 = gather_seq_out_by_pos( sequence_output, masked_pos_2) prediction_scores_masked_2, _ = self.cls2( sequence_output_masked_2, None) masked_lm_loss_2 = self.crit_mask_lm( prediction_scores_masked_2.transpose(1, 2).float(), masked_labels_2) masked_lm_loss_2 = loss_mask_and_normalize( masked_lm_loss_2.float(), masked_weights_2) masked_lm_loss = masked_lm_loss + masked_lm_loss_2 if pair_x is None or pair_y is None or pair_r is None or pair_pos_neg_mask is None or pair_loss_mask is None: return masked_lm_loss, next_sentence_loss # pair and relation if pair_x_mask is None or pair_y_mask is None: pair_x_output_masked = gather_seq_out_by_pos( sequence_output, pair_x) pair_y_output_masked = gather_seq_out_by_pos( sequence_output, pair_y) else: pair_x_output_masked = gather_seq_out_by_pos_average( sequence_output, pair_x, pair_x_mask) pair_y_output_masked = gather_seq_out_by_pos_average( sequence_output, pair_y, pair_y_mask) pair_loss = self.crit_pair_rel( pair_x_output_masked, pair_y_output_masked, pair_r, pair_pos_neg_mask) pair_loss = loss_mask_and_normalize( pair_loss.float(), pair_loss_mask) return masked_lm_loss, next_sentence_loss, pair_loss class BertForExtractiveSummarization(PreTrainedBertModel): """refer to BertForPreTraining""" def __init__(self, config): super(BertForExtractiveSummarization, self).__init__(config) self.bert = BertModel(config) self.secondary_pred_proj = nn.Embedding(2, config.hidden_size) self.cls2 = BertPreTrainingHeads( config, self.secondary_pred_proj.weight, num_labels=2) self.apply(self.init_bert_weights) def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_pos_2=None, masked_weights_2=None, task_idx=None, mask_qkv=None): sequence_output, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False, mask_qkv=mask_qkv, task_idx=task_idx) def gather_seq_out_by_pos(seq, pos): return torch.gather(seq, 1, pos.unsqueeze(2).expand(-1, -1, seq.size(-1))) sequence_output_masked_2 = gather_seq_out_by_pos( sequence_output, masked_pos_2) prediction_scores_masked_2, _ = self.cls2( sequence_output_masked_2, None, task_idx=task_idx) predicted_probs = torch.nn.functional.softmax( prediction_scores_masked_2, dim=-1) return predicted_probs, masked_pos_2, masked_weights_2 class BertForSeq2SeqDecoder(PreTrainedBertModel): """refer to BertForPreTraining""" def __init__(self, config, mask_word_id=0, num_labels=2, num_rel=0, search_beam_size=1, length_penalty=1.0, eos_id=0, sos_id=0, forbid_duplicate_ngrams=False, forbid_ignore_set=None, not_predict_set=None, ngram_size=3, min_len=0, mode="s2s", pos_shift=False): super(BertForSeq2SeqDecoder, self).__init__(config) self.bert = BertModelIncr(config) self.cls = BertPreTrainingHeads( config, self.bert.embeddings.word_embeddings.weight, num_labels=num_labels) self.apply(self.init_bert_weights) self.crit_mask_lm = nn.CrossEntropyLoss(reduction='none') self.crit_next_sent = nn.CrossEntropyLoss(ignore_index=-1) self.mask_word_id = mask_word_id self.num_labels = num_labels self.num_rel = num_rel if self.num_rel > 0: self.crit_pair_rel = BertPreTrainingPairRel( config, num_rel=num_rel) self.search_beam_size = search_beam_size self.length_penalty = length_penalty self.eos_id = eos_id self.sos_id = sos_id self.forbid_duplicate_ngrams = forbid_duplicate_ngrams self.forbid_ignore_set = forbid_ignore_set self.not_predict_set = not_predict_set self.ngram_size = ngram_size self.min_len = min_len assert mode in ("s2s", "l2r") self.mode = mode self.pos_shift = pos_shift def forward(self, input_ids, token_type_ids, position_ids, attention_mask, task_idx=None, mask_qkv=None): if self.search_beam_size > 1: return self.beam_search(input_ids, token_type_ids, position_ids, attention_mask, task_idx=task_idx, mask_qkv=mask_qkv) input_shape = list(input_ids.size()) batch_size = input_shape[0] input_length = input_shape[1] output_shape = list(token_type_ids.size()) output_length = output_shape[1] output_ids = [] prev_embedding = None prev_encoded_layers = None curr_ids = input_ids mask_ids = input_ids.new(batch_size, 1).fill_(self.mask_word_id) next_pos = input_length if self.pos_shift: sos_ids = input_ids.new(batch_size, 1).fill_(self.sos_id) while next_pos < output_length: curr_length = list(curr_ids.size())[1] if self.pos_shift: if next_pos == input_length: x_input_ids = torch.cat((curr_ids, sos_ids), dim=1) start_pos = 0 else: x_input_ids = curr_ids start_pos = next_pos else: start_pos = next_pos - curr_length x_input_ids = torch.cat((curr_ids, mask_ids), dim=1) curr_token_type_ids = token_type_ids[:, start_pos:next_pos+1] curr_attention_mask = attention_mask[:, start_pos:next_pos+1, :next_pos+1] curr_position_ids = position_ids[:, start_pos:next_pos+1] new_embedding, new_encoded_layers, _ = \ self.bert(x_input_ids, curr_token_type_ids, curr_position_ids, curr_attention_mask, output_all_encoded_layers=True, prev_embedding=prev_embedding, prev_encoded_layers=prev_encoded_layers, mask_qkv=mask_qkv) last_hidden = new_encoded_layers[-1][:, -1:, :] prediction_scores, _ = self.cls( last_hidden, None, task_idx=task_idx) if self.not_predict_set: for token_id in self.not_predict_set: prediction_scores[:, :, token_id].fill_(-10000.0) _, max_ids = torch.max(prediction_scores, dim=-1) output_ids.append(max_ids) if self.pos_shift: if prev_embedding is None: prev_embedding = new_embedding else: prev_embedding = torch.cat( (prev_embedding, new_embedding), dim=1) if prev_encoded_layers is None: prev_encoded_layers = [x for x in new_encoded_layers] else: prev_encoded_layers = [torch.cat((x[0], x[1]), dim=1) for x in zip( prev_encoded_layers, new_encoded_layers)] else: if prev_embedding is None: prev_embedding = new_embedding[:, :-1, :] else: prev_embedding = torch.cat( (prev_embedding, new_embedding[:, :-1, :]), dim=1) if prev_encoded_layers is None: prev_encoded_layers = [x[:, :-1, :] for x in new_encoded_layers] else: prev_encoded_layers = [torch.cat((x[0], x[1][:, :-1, :]), dim=1) for x in zip(prev_encoded_layers, new_encoded_layers)] curr_ids = max_ids next_pos += 1 return torch.cat(output_ids, dim=1) def beam_search(self, input_ids, token_type_ids, position_ids, attention_mask, task_idx=None, mask_qkv=None): input_shape = list(input_ids.size()) batch_size = input_shape[0] input_length = input_shape[1] output_shape = list(token_type_ids.size()) output_length = output_shape[1] output_ids = [] prev_embedding = None prev_encoded_layers = None curr_ids = input_ids mask_ids = input_ids.new(batch_size, 1).fill_(self.mask_word_id) next_pos = input_length if self.pos_shift: sos_ids = input_ids.new(batch_size, 1).fill_(self.sos_id) K = self.search_beam_size total_scores = [] beam_masks = [] step_ids = [] step_back_ptrs = [] partial_seqs = [] forbid_word_mask = None buf_matrix = None while next_pos < output_length: curr_length = list(curr_ids.size())[1] if self.pos_shift: if next_pos == input_length: x_input_ids = torch.cat((curr_ids, sos_ids), dim=1) start_pos = 0 else: x_input_ids = curr_ids start_pos = next_pos else: start_pos = next_pos - curr_length x_input_ids = torch.cat((curr_ids, mask_ids), dim=1) curr_token_type_ids = token_type_ids[:, start_pos:next_pos + 1] curr_attention_mask = attention_mask[:, start_pos:next_pos + 1, :next_pos + 1] curr_position_ids = position_ids[:, start_pos:next_pos + 1] new_embedding, new_encoded_layers, _ = \ self.bert(x_input_ids, curr_token_type_ids, curr_position_ids, curr_attention_mask, output_all_encoded_layers=True, prev_embedding=prev_embedding, prev_encoded_layers=prev_encoded_layers, mask_qkv=mask_qkv) last_hidden = new_encoded_layers[-1][:, -1:, :] prediction_scores, _ = self.cls( last_hidden, None, task_idx=task_idx) log_scores = torch.nn.functional.log_softmax( prediction_scores, dim=-1) if forbid_word_mask is not None: log_scores += (forbid_word_mask * -10000.0) if self.min_len and (next_pos-input_length+1 <= self.min_len): log_scores[:, :, self.eos_id].fill_(-10000.0) if self.not_predict_set: for token_id in self.not_predict_set: log_scores[:, :, token_id].fill_(-10000.0) kk_scores, kk_ids = torch.topk(log_scores, k=K) if len(total_scores) == 0: k_ids = torch.reshape(kk_ids, [batch_size, K]) back_ptrs = torch.zeros(batch_size, K, dtype=torch.long) k_scores = torch.reshape(kk_scores, [batch_size, K]) else: last_eos = torch.reshape( beam_masks[-1], [batch_size * K, 1, 1]) last_seq_scores = torch.reshape( total_scores[-1], [batch_size * K, 1, 1]) kk_scores += last_eos * (-10000.0) + last_seq_scores kk_scores = torch.reshape(kk_scores, [batch_size, K * K]) k_scores, k_ids = torch.topk(kk_scores, k=K) back_ptrs = torch.div(k_ids, K) kk_ids = torch.reshape(kk_ids, [batch_size, K * K]) k_ids = torch.gather(kk_ids, 1, k_ids) step_back_ptrs.append(back_ptrs) step_ids.append(k_ids) beam_masks.append(torch.eq(k_ids, self.eos_id).float()) total_scores.append(k_scores) def first_expand(x): input_shape = list(x.size()) expanded_shape = input_shape[:1] + [1] + input_shape[1:] x = torch.reshape(x, expanded_shape) repeat_count = [1, K] + [1] * (len(input_shape) - 1) x = x.repeat(*repeat_count) x = torch.reshape(x, [input_shape[0] * K] + input_shape[1:]) return x def select_beam_items(x, ids): id_shape = list(ids.size()) id_rank = len(id_shape) assert len(id_shape) == 2 x_shape = list(x.size()) x = torch.reshape(x, [batch_size, K] + x_shape[1:]) x_rank = len(x_shape) + 1 assert x_rank >= 2 if id_rank < x_rank: ids = torch.reshape( ids, id_shape + [1] * (x_rank - id_rank)) ids = ids.expand(id_shape + x_shape[1:]) y = torch.gather(x, 1, ids) y = torch.reshape(y, x_shape) return y is_first = (prev_embedding is None) if self.pos_shift: if prev_embedding is None: prev_embedding = first_expand(new_embedding) else: prev_embedding = torch.cat( (prev_embedding, new_embedding), dim=1) prev_embedding = select_beam_items( prev_embedding, back_ptrs) if prev_encoded_layers is None: prev_encoded_layers = [first_expand( x) for x in new_encoded_layers] else: prev_encoded_layers = [torch.cat((x[0], x[1]), dim=1) for x in zip( prev_encoded_layers, new_encoded_layers)] prev_encoded_layers = [select_beam_items( x, back_ptrs) for x in prev_encoded_layers] else: if prev_embedding is None: prev_embedding = first_expand(new_embedding[:, :-1, :]) else: prev_embedding = torch.cat( (prev_embedding, new_embedding[:, :-1, :]), dim=1) prev_embedding = select_beam_items( prev_embedding, back_ptrs) if prev_encoded_layers is None: prev_encoded_layers = [first_expand( x[:, :-1, :]) for x in new_encoded_layers] else: prev_encoded_layers = [torch.cat((x[0], x[1][:, :-1, :]), dim=1) for x in zip(prev_encoded_layers, new_encoded_layers)] prev_encoded_layers = [select_beam_items( x, back_ptrs) for x in prev_encoded_layers] curr_ids = torch.reshape(k_ids, [batch_size * K, 1]) if is_first: token_type_ids = first_expand(token_type_ids) position_ids = first_expand(position_ids) attention_mask = first_expand(attention_mask) mask_ids = first_expand(mask_ids) if mask_qkv is not None: mask_qkv = first_expand(mask_qkv) if self.forbid_duplicate_ngrams: wids = step_ids[-1].tolist() ptrs = step_back_ptrs[-1].tolist() if is_first: partial_seqs = [] for b in range(batch_size): for k in range(K): partial_seqs.append([wids[b][k]]) else: new_partial_seqs = [] for b in range(batch_size): for k in range(K): new_partial_seqs.append( partial_seqs[ptrs[b][k] + b * K] + [wids[b][k]]) partial_seqs = new_partial_seqs def get_dup_ngram_candidates(seq, n): cands = set() if len(seq) < n: return [] tail = seq[-(n-1):] if self.forbid_ignore_set and any(tk in self.forbid_ignore_set for tk in tail): return [] for i in range(len(seq) - (n - 1)): mismatch = False for j in range(n - 1): if tail[j] != seq[i + j]: mismatch = True break if (not mismatch) and not(self.forbid_ignore_set and (seq[i + n - 1] in self.forbid_ignore_set)): cands.add(seq[i + n - 1]) return list(sorted(cands)) if len(partial_seqs[0]) >= self.ngram_size: dup_cands = [] for seq in partial_seqs: dup_cands.append( get_dup_ngram_candidates(seq, self.ngram_size)) if max(len(x) for x in dup_cands) > 0: if buf_matrix is None: vocab_size = list(log_scores.size())[-1] buf_matrix = np.zeros( (batch_size * K, vocab_size), dtype=float) else: buf_matrix.fill(0) for bk, cands in enumerate(dup_cands): for i, wid in enumerate(cands): buf_matrix[bk, wid] = 1.0 forbid_word_mask = torch.tensor( buf_matrix, dtype=log_scores.dtype) forbid_word_mask = torch.reshape( forbid_word_mask, [batch_size * K, 1, vocab_size]).cuda() else: forbid_word_mask = None next_pos += 1 # [(batch, beam)] total_scores = [x.tolist() for x in total_scores] step_ids = [x.tolist() for x in step_ids] step_back_ptrs = [x.tolist() for x in step_back_ptrs] # back tracking traces = {'pred_seq': [], 'scores': [], 'wids': [], 'ptrs': []} for b in range(batch_size): # [(beam,)] scores = [x[b] for x in total_scores] wids_list = [x[b] for x in step_ids] ptrs = [x[b] for x in step_back_ptrs] traces['scores'].append(scores) traces['wids'].append(wids_list) traces['ptrs'].append(ptrs) # first we need to find the eos frame where all symbols are eos # any frames after the eos frame are invalid last_frame_id = len(scores) - 1 for i, wids in enumerate(wids_list): if all(wid == self.eos_id for wid in wids): last_frame_id = i break max_score = -math.inf frame_id = -1 pos_in_frame = -1 for fid in range(last_frame_id + 1): for i, wid in enumerate(wids_list[fid]): if wid == self.eos_id or fid == last_frame_id: s = scores[fid][i] if self.length_penalty > 0: s /= math.pow((5 + fid + 1) / 6.0, self.length_penalty) if s > max_score: max_score = s frame_id = fid pos_in_frame = i if frame_id == -1: traces['pred_seq'].append([0]) else: seq = [wids_list[frame_id][pos_in_frame]] for fid in range(frame_id, 0, -1): pos_in_frame = ptrs[fid][pos_in_frame] seq.append(wids_list[fid - 1][pos_in_frame]) seq.reverse() traces['pred_seq'].append(seq) def _pad_sequence(sequences, max_len, padding_value=0): trailing_dims = sequences[0].size()[1:] out_dims = (len(sequences), max_len) + trailing_dims out_tensor = sequences[0].data.new(*out_dims).fill_(padding_value) for i, tensor in enumerate(sequences): length = tensor.size(0) # use index notation to prevent duplicate references to the tensor out_tensor[i, :length, ...] = tensor return out_tensor # convert to tensors for DataParallel for k in ('pred_seq', 'scores', 'wids', 'ptrs'): ts_list = traces[k] if not isinstance(ts_list[0], torch.Tensor): dt = torch.float if k == 'scores' else torch.long ts_list = [torch.tensor(it, dtype=dt) for it in ts_list] traces[k] = _pad_sequence( ts_list, output_length, padding_value=0).to(input_ids.device) return traces class BertForMaskedLM(PreTrainedBertModel): """BERT model with the masked language modeling head. This module comprises the BERT model followed by the masked language modeling head. Params: config: a BertConfig class instance with the configuration to build a new model. Inputs: `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts `extract_features.py`, `run_classifier.py` and `run_squad.py`) `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to a `sentence B` token (see BERT paper for more details). `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max input sequence length in the current batch. It's the mask that we typically use for attention when a batch has varying length sentences. `masked_lm_labels`: masked language modeling labels: torch.LongTensor of shape [batch_size, sequence_length] with indices selected in [-1, 0, ..., vocab_size]. All labels set to -1 are ignored (masked), the loss is only computed for the labels set in [0, ..., vocab_size] Outputs: if `masked_lm_labels` is `None`: Outputs the masked language modeling loss. if `masked_lm_labels` is `None`: Outputs the masked language modeling logits of shape [batch_size, sequence_length, vocab_size]. Example usage: ```python # Already been converted into WordPiece token ids input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) model = BertForMaskedLM(config) masked_lm_logits_scores = model(input_ids, token_type_ids, input_mask) ``` """ def __init__(self, config): super(BertForMaskedLM, self).__init__(config) self.bert = BertModel(config) self.cls = BertOnlyMLMHead( config, self.bert.embeddings.word_embeddings.weight) self.apply(self.init_bert_weights) def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, mask_qkv=None, task_idx=None): sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False, mask_qkv=mask_qkv, task_idx=task_idx) prediction_scores = self.cls(sequence_output) if masked_lm_labels is not None: loss_fct = CrossEntropyLoss(ignore_index=-1) masked_lm_loss = loss_fct( prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1)) return masked_lm_loss else: return prediction_scores class BertForNextSentencePrediction(PreTrainedBertModel): """BERT model with next sentence prediction head. This module comprises the BERT model followed by the next sentence classification head. Params: config: a BertConfig class instance with the configuration to build a new model. Inputs: `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts `extract_features.py`, `run_classifier.py` and `run_squad.py`) `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to a `sentence B` token (see BERT paper for more details). `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max input sequence length in the current batch. It's the mask that we typically use for attention when a batch has varying length sentences. `next_sentence_label`: next sentence classification loss: torch.LongTensor of shape [batch_size] with indices selected in [0, 1]. 0 => next sentence is the continuation, 1 => next sentence is a random sentence. Outputs: if `next_sentence_label` is not `None`: Outputs the total_loss which is the sum of the masked language modeling loss and the next sentence classification loss. if `next_sentence_label` is `None`: Outputs the next sentence classification logits of shape [batch_size, 2]. Example usage: ```python # Already been converted into WordPiece token ids input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) model = BertForNextSentencePrediction(config) seq_relationship_logits = model(input_ids, token_type_ids, input_mask) ``` """ def __init__(self, config): super(BertForNextSentencePrediction, self).__init__(config) self.bert = BertModel(config) self.cls = BertOnlyNSPHead(config) self.apply(self.init_bert_weights) def forward(self, input_ids, token_type_ids=None, attention_mask=None, next_sentence_label=None, mask_qkv=None, task_idx=None): _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False, mask_qkv=mask_qkv, task_idx=task_idx) seq_relationship_score = self.cls(pooled_output) if next_sentence_label is not None: loss_fct = CrossEntropyLoss(ignore_index=-1) next_sentence_loss = loss_fct( seq_relationship_score.view(-1, 2), next_sentence_label.view(-1)) return next_sentence_loss else: return seq_relationship_score class BertForSequenceClassification(PreTrainedBertModel): """BERT model for classification. This module is composed of the BERT model with a linear layer on top of the pooled output. Params: `config`: a BertConfig class instance with the configuration to build a new model. `num_labels`: the number of classes for the classifier. Default = 2. Inputs: `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts `extract_features.py`, `run_classifier.py` and `run_squad.py`) `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to a `sentence B` token (see BERT paper for more details). `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max input sequence length in the current batch. It's the mask that we typically use for attention when a batch has varying length sentences. `labels`: labels for the classification output: torch.LongTensor of shape [batch_size] with indices selected in [0, ..., num_labels]. Outputs: if `labels` is not `None`: Outputs the CrossEntropy classification loss of the output with the labels. if `labels` is `None`: Outputs the classification logits of shape [batch_size, num_labels]. Example usage: ```python # Already been converted into WordPiece token ids input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) num_labels = 2 model = BertForSequenceClassification(config, num_labels) logits = model(input_ids, token_type_ids, input_mask) ``` """ def __init__(self, config, num_labels=2): super(BertForSequenceClassification, self).__init__(config) self.num_labels = num_labels self.bert = BertModel(config) self.dropout = nn.Dropout(config.hidden_dropout_prob) self.classifier = nn.Linear(config.hidden_size, num_labels) self.apply(self.init_bert_weights) def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, mask_qkv=None, task_idx=None): _, pooled_output = self.bert( input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False, mask_qkv=mask_qkv, task_idx=task_idx) pooled_output = self.dropout(pooled_output) logits = self.classifier(pooled_output) if labels is not None: if labels.dtype == torch.long: loss_fct = CrossEntropyLoss() loss = loss_fct( logits.view(-1, self.num_labels), labels.view(-1)) elif labels.dtype == torch.half or labels.dtype == torch.float: loss_fct = MSELoss() loss = loss_fct(logits.view(-1), labels.view(-1)) else: print('unkown labels.dtype') loss = None return loss else: return logits class BertForMultipleChoice(PreTrainedBertModel): """BERT model for multiple choice tasks. This module is composed of the BERT model with a linear layer on top of the pooled output. Params: `config`: a BertConfig class instance with the configuration to build a new model. `num_choices`: the number of classes for the classifier. Default = 2. Inputs: `input_ids`: a torch.LongTensor of shape [batch_size, num_choices, sequence_length] with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts `extract_features.py`, `run_classifier.py` and `run_squad.py`) `token_type_ids`: an optional torch.LongTensor of shape [batch_size, num_choices, sequence_length] with the token types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to a `sentence B` token (see BERT paper for more details). `attention_mask`: an optional torch.LongTensor of shape [batch_size, num_choices, sequence_length] with indices selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max input sequence length in the current batch. It's the mask that we typically use for attention when a batch has varying length sentences. `labels`: labels for the classification output: torch.LongTensor of shape [batch_size] with indices selected in [0, ..., num_choices]. Outputs: if `labels` is not `None`: Outputs the CrossEntropy classification loss of the output with the labels. if `labels` is `None`: Outputs the classification logits of shape [batch_size, num_labels]. Example usage: ```python # Already been converted into WordPiece token ids input_ids = torch.LongTensor([[[31, 51, 99], [15, 5, 0]], [[12, 16, 42], [14, 28, 57]]]) input_mask = torch.LongTensor([[[1, 1, 1], [1, 1, 0]],[[1,1,0], [1, 0, 0]]]) token_type_ids = torch.LongTensor([[[0, 0, 1], [0, 1, 0]],[[0, 1, 1], [0, 0, 1]]]) config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) num_choices = 2 model = BertForMultipleChoice(config, num_choices) logits = model(input_ids, token_type_ids, input_mask) ``` """ def __init__(self, config, num_choices=2): super(BertForMultipleChoice, self).__init__(config) self.num_choices = num_choices self.bert = BertModel(config) self.dropout = nn.Dropout(config.hidden_dropout_prob) self.classifier = nn.Linear(config.hidden_size, 1) self.apply(self.init_bert_weights) def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, mask_qkv=None, task_idx=None): flat_input_ids = input_ids.view(-1, input_ids.size(-1)) flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) _, pooled_output = self.bert( flat_input_ids, flat_token_type_ids, flat_attention_mask, output_all_encoded_layers=False, mask_qkv=mask_qkv, task_idx=task_idx) pooled_output = self.dropout(pooled_output) logits = self.classifier(pooled_output) reshaped_logits = logits.view(-1, self.num_choices) if labels is not None: loss_fct = CrossEntropyLoss() loss = loss_fct(reshaped_logits, labels) return loss else: return reshaped_logits class BertForTokenClassification(PreTrainedBertModel): """BERT model for token-level classification. This module is composed of the BERT model with a linear layer on top of the full hidden state of the last layer. Params: `config`: a BertConfig class instance with the configuration to build a new model. `num_labels`: the number of classes for the classifier. Default = 2. Inputs: `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts `extract_features.py`, `run_classifier.py` and `run_squad.py`) `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to a `sentence B` token (see BERT paper for more details). `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max input sequence length in the current batch. It's the mask that we typically use for attention when a batch has varying length sentences. `labels`: labels for the classification output: torch.LongTensor of shape [batch_size] with indices selected in [0, ..., num_labels]. Outputs: if `labels` is not `None`: Outputs the CrossEntropy classification loss of the output with the labels. if `labels` is `None`: Outputs the classification logits of shape [batch_size, sequence_length, num_labels]. Example usage: ```python # Already been converted into WordPiece token ids input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) num_labels = 2 model = BertForTokenClassification(config, num_labels) logits = model(input_ids, token_type_ids, input_mask) ``` """ def __init__(self, config, num_labels=2): super(BertForTokenClassification, self).__init__(config) self.num_labels = num_labels self.bert = BertModel(config) self.dropout = nn.Dropout(config.hidden_dropout_prob) self.classifier = nn.Linear(config.hidden_size, num_labels) self.apply(self.init_bert_weights) def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, mask_qkv=None, task_idx=None): sequence_output, _ = self.bert( input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False, mask_qkv=mask_qkv, task_idx=task_idx) sequence_output = self.dropout(sequence_output) logits = self.classifier(sequence_output) if labels is not None: loss_fct = CrossEntropyLoss() # Only keep active parts of the loss if attention_mask is not None: active_loss = attention_mask.view(-1) == 1 active_logits = logits.view(-1, self.num_labels)[active_loss] active_labels = labels.view(-1)[active_loss] loss = loss_fct(active_logits, active_labels) else: loss = loss_fct( logits.view(-1, self.num_labels), labels.view(-1)) return loss else: return logits class BertForQuestionAnswering(PreTrainedBertModel): """BERT model for Question Answering (span extraction). This module is composed of the BERT model with a linear layer on top of the sequence output that computes start_logits and end_logits Params: `config`: either - a BertConfig class instance with the configuration to build a new model, or - a str with the name of a pre-trained model to load selected in the list of: . `bert-base-uncased` . `bert-large-uncased` . `bert-base-cased` . `bert-base-multilingual` . `bert-base-chinese` The pre-trained model will be downloaded and cached if needed. Inputs: `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts `extract_features.py`, `run_classifier.py` and `run_squad.py`) `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to a `sentence B` token (see BERT paper for more details). `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max input sequence length in the current batch. It's the mask that we typically use for attention when a batch has varying length sentences. `start_positions`: position of the first token for the labeled span: torch.LongTensor of shape [batch_size]. Positions are clamped to the length of the sequence and position outside of the sequence are not taken into account for computing the loss. `end_positions`: position of the last token for the labeled span: torch.LongTensor of shape [batch_size]. Positions are clamped to the length of the sequence and position outside of the sequence are not taken into account for computing the loss. Outputs: if `start_positions` and `end_positions` are not `None`: Outputs the total_loss which is the sum of the CrossEntropy loss for the start and end token positions. if `start_positions` or `end_positions` is `None`: Outputs a tuple of start_logits, end_logits which are the logits respectively for the start and end position tokens of shape [batch_size, sequence_length]. Example usage: ```python # Already been converted into WordPiece token ids input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) model = BertForQuestionAnswering(config) start_logits, end_logits = model(input_ids, token_type_ids, input_mask) ``` """ def __init__(self, config): super(BertForQuestionAnswering, self).__init__(config) self.bert = BertModel(config) # self.dropout = nn.Dropout(config.hidden_dropout_prob) self.qa_outputs = nn.Linear(config.hidden_size, 2) self.apply(self.init_bert_weights) def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_positions=None, end_positions=None, task_idx=None): sequence_output, _ = self.bert( input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False, task_idx=task_idx) logits = self.qa_outputs(sequence_output) start_logits, end_logits = logits.split(1, dim=-1) start_logits = start_logits.squeeze(-1) end_logits = end_logits.squeeze(-1) if start_positions is not None and end_positions is not None: # If we are on multi-GPU, split add a dimension if len(start_positions.size()) > 1: start_positions = start_positions.squeeze(-1) if len(end_positions.size()) > 1: end_positions = end_positions.squeeze(-1) # sometimes the start/end positions are outside our model inputs, we ignore these terms ignored_index = start_logits.size(1) start_positions.clamp_(0, ignored_index) end_positions.clamp_(0, ignored_index) loss_fct = CrossEntropyLoss(ignore_index=ignored_index) start_loss = loss_fct(start_logits, start_positions) end_loss = loss_fct(end_logits, end_positions) total_loss = (start_loss + end_loss) / 2 return total_loss else: return start_logits, end_logits