# coding=utf-8
# Copyright 2021 The Google Flax Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, Optional, Tuple

import numpy as np

import flax
import flax.linen as nn
import jax
import jax.numpy as jnp
import jaxlib.xla_extension as jax_xla
from flax.core.frozen_dict import FrozenDict
from flax.linen.attention import dot_product_attention_weights
from jax import lax
from jax.random import PRNGKey

from ...file_utils import ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward
from ...modeling_flax_outputs import (
    FlaxBaseModelOutput,
    FlaxMaskedLMOutput,
    FlaxMultipleChoiceModelOutput,
    FlaxQuestionAnsweringModelOutput,
    FlaxSequenceClassifierOutput,
    FlaxTokenClassifierOutput,
)
from ...modeling_flax_utils import (
    ACT2FN,
    FlaxPreTrainedModel,
    append_call_sample_docstring,
    append_replace_return_docstrings,
    overwrite_call_docstring,
)
from ...utils import logging
from .configuration_electra import ElectraConfig


logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "google/electra-small-discriminator"
_CONFIG_FOR_DOC = "ElectraConfig"
_TOKENIZER_FOR_DOC = "ElectraTokenizer"


@flax.struct.dataclass
class FlaxElectraForPreTrainingOutput(ModelOutput):
    """
    Output type of :class:`~transformers.FlaxElectraForPreTraining`.

    Args:
        logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length)`):
            Raw prediction scores of the discriminator head (one score per input token).
        hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
            layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
            Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
            sequence_length, sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    logits: jax_xla.DeviceArray = None
    hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
    attentions: Optional[Tuple[jax_xla.DeviceArray]] = None


ELECTRA_START_DOCSTRING = r"""

    This model inherits from :class:`~transformers.FlaxPreTrainedModel`. Check the superclass documentation for the
    generic methods the library implements for all its model (such as downloading, saving and converting weights from
    PyTorch models)

    This model is also a Flax Linen `flax.nn.Module `__ subclass. Use it as a regular Flax Module and refer to the
    Flax documentation for all matter related to general usage and behavior.

    Finally, this model supports inherent JAX features such as:

    - `Just-In-Time (JIT) compilation `__
    - `Automatic Differentiation `__
    - `Vectorization `__
    - `Parallelization `__

    Parameters:
        config (:class:`~transformers.ElectraConfig`): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the
            model weights.
"""

ELECTRA_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (:obj:`numpy.ndarray` of shape :obj:`({0})`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using :class:`~transformers.ElectraTokenizer`. See
            :meth:`transformers.PreTrainedTokenizer.encode` and :func:`transformers.PreTrainedTokenizer.__call__` for
            details.

            `What are input IDs? <../glossary.html#input-ids>`__
        attention_mask (:obj:`numpy.ndarray` of shape :obj:`({0})`, `optional`):
            Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            `What are attention masks? <../glossary.html#attention-mask>`__
        token_type_ids (:obj:`numpy.ndarray` of shape :obj:`({0})`, `optional`):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0,
            1]``:

            - 0 corresponds to a `sentence A` token,
            - 1 corresponds to a `sentence B` token.

            `What are token type IDs? <../glossary.html#token-type-ids>`__
        position_ids (:obj:`numpy.ndarray` of shape :obj:`({0})`, `optional`):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0,
            config.max_position_embeddings - 1]``.
        return_dict (:obj:`bool`, `optional`):
            Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.

"""


class FlaxElectraEmbeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings."""

    config: ElectraConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        # All three tables share `config.embedding_size` and the same normal initializer.
        self.word_embeddings = nn.Embed(
            self.config.vocab_size,
            self.config.embedding_size,
            embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
            dtype=self.dtype,
        )
        self.position_embeddings = nn.Embed(
            self.config.max_position_embeddings,
            self.config.embedding_size,
            embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
            dtype=self.dtype,
        )
        self.token_type_embeddings = nn.Embed(
            self.config.type_vocab_size,
            self.config.embedding_size,
            embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
            dtype=self.dtype,
        )
        self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)

    # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertEmbeddings.__call__
    def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, deterministic: bool = True):
        """Sum the three embeddings, then apply layer norm and dropout. `attention_mask` is accepted but unused."""
        # Embed
        inputs_embeds = self.word_embeddings(input_ids.astype("i4"))
        position_embeds = self.position_embeddings(position_ids.astype("i4"))
        token_type_embeddings = self.token_type_embeddings(token_type_ids.astype("i4"))

        # Sum all embeddings
        hidden_states = inputs_embeds + token_type_embeddings + position_embeds

        # Layer Norm
        hidden_states = self.LayerNorm(hidden_states)
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)
        return hidden_states


# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertSelfAttention with Bert->Electra
# NOTE(review): the error message below was turned into a real f-string, so this block now differs from the Bert
# original -- mirror the fix there if a copy-consistency check is enforced.
class FlaxElectraSelfAttention(nn.Module):
    """Multi-head self-attention: query/key/value projections followed by dot-product attention."""

    config: ElectraConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        """Create the query/key/value projections.

        Raises:
            ValueError: if `config.hidden_size` is not a multiple of `config.num_attention_heads`.
        """
        if self.config.hidden_size % self.config.num_attention_heads != 0:
            raise ValueError(
                f"`config.hidden_size`: {self.config.hidden_size} has to be a multiple of "
                f"`config.num_attention_heads`: {self.config.num_attention_heads}"
            )

        self.query = nn.Dense(
            self.config.hidden_size,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
        )
        self.key = nn.Dense(
            self.config.hidden_size,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
        )
        self.value = nn.Dense(
            self.config.hidden_size,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
        )

    def __call__(self, hidden_states, attention_mask, deterministic=True, output_attentions: bool = False):
        """Return `(attn_output,)`, plus the attention weights when `output_attentions` is set."""
        head_dim = self.config.hidden_size // self.config.num_attention_heads

        # Split the projected features into (num_heads, head_dim) while keeping the two leading axes.
        query_states = self.query(hidden_states).reshape(
            hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
        )
        value_states = self.value(hidden_states).reshape(
            hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
        )
        key_states = self.key(hidden_states).reshape(
            hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
        )

        # Convert the boolean attention mask to an attention bias.
        if attention_mask is not None:
            # attention mask in the form of attention bias
            attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
            attention_bias = lax.select(
                attention_mask > 0,
                jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
                jnp.full(attention_mask.shape, -1e10).astype(self.dtype),
            )
        else:
            attention_bias = None

        # A dropout RNG is only requested when dropout can actually be applied.
        dropout_rng = None
        if not deterministic and self.config.attention_probs_dropout_prob > 0.0:
            dropout_rng = self.make_rng("dropout")

        attn_weights = dot_product_attention_weights(
            query_states,
            key_states,
            bias=attention_bias,
            dropout_rng=dropout_rng,
            dropout_rate=self.config.attention_probs_dropout_prob,
            broadcast_dropout=True,
            deterministic=deterministic,
            dtype=self.dtype,
            precision=None,
        )

        attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
        # Merge the head axes back into a single feature axis.
        attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,))

        outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
        return outputs


# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertSelfOutput with Bert->Electra
class FlaxElectraSelfOutput(nn.Module):
    """Dense projection, dropout and residual layer norm applied to the self-attention output."""

    config: ElectraConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.dense = nn.Dense(
            self.config.hidden_size,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
            dtype=self.dtype,
        )
        self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)

    def __call__(self, hidden_states, input_tensor, deterministic: bool = True):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)
        # Residual connection with the attention block's input.
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertAttention with Bert->Electra
class FlaxElectraAttention(nn.Module):
    """Self-attention followed by its output projection / residual block."""

    config: ElectraConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.self = FlaxElectraSelfAttention(self.config, dtype=self.dtype)
        self.output = FlaxElectraSelfOutput(self.config, dtype=self.dtype)

    def __call__(self, hidden_states, attention_mask, deterministic=True, output_attentions: bool = False):
        # Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length)
        # FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable
        # with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length)
        attn_outputs = self.self(
            hidden_states, attention_mask, deterministic=deterministic, output_attentions=output_attentions
        )
        attn_output = attn_outputs[0]
        hidden_states = self.output(attn_output, hidden_states, deterministic=deterministic)

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attn_outputs[1],)

        return outputs


# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertIntermediate with Bert->Electra
class FlaxElectraIntermediate(nn.Module):
    """Feed-forward expansion to `config.intermediate_size` followed by `config.hidden_act`."""

    config: ElectraConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.dense = nn.Dense(
            self.config.intermediate_size,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
            dtype=self.dtype,
        )
        self.activation = ACT2FN[self.config.hidden_act]

    def __call__(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.activation(hidden_states)
        return hidden_states


# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertOutput with Bert->Electra
class FlaxElectraOutput(nn.Module):
    """Feed-forward projection back to `config.hidden_size`, with dropout and residual layer norm."""

    config: ElectraConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.dense = nn.Dense(
            self.config.hidden_size,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
            dtype=self.dtype,
        )
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
        self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)

    def __call__(self, hidden_states, attention_output, deterministic: bool = True):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)
        # Residual connection with the attention output.
        hidden_states = self.LayerNorm(hidden_states + attention_output)
        return hidden_states


# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertLayer with Bert->Electra
class FlaxElectraLayer(nn.Module):
    """One transformer layer: attention block followed by the feed-forward block."""

    config: ElectraConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.attention = FlaxElectraAttention(self.config, dtype=self.dtype)
        self.intermediate = FlaxElectraIntermediate(self.config, dtype=self.dtype)
        self.output = FlaxElectraOutput(self.config, dtype=self.dtype)

    def __call__(self, hidden_states, attention_mask, deterministic: bool = True, output_attentions: bool = False):
        attention_outputs = self.attention(
            hidden_states, attention_mask, deterministic=deterministic, output_attentions=output_attentions
        )
        attention_output = attention_outputs[0]

        hidden_states = self.intermediate(attention_output)
        hidden_states = self.output(hidden_states, attention_output, deterministic=deterministic)

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attention_outputs[1],)
        return outputs


# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertLayerCollection with Bert->Electra
# NOTE(review): the tuple return path below now includes hidden states / attentions, so this block differs from
# the Bert original -- mirror the fix there if a copy-consistency check is enforced.
class FlaxElectraLayerCollection(nn.Module):
    """Stack of `config.num_hidden_layers` transformer layers applied sequentially."""

    config: ElectraConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.layers = [
            FlaxElectraLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers)
        ]

    def __call__(
        self,
        hidden_states,
        attention_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        """Run every layer in order.

        Returns:
            A :class:`FlaxBaseModelOutput` when `return_dict` is true. Otherwise a tuple holding the last hidden
            state, followed by the collected hidden states and attentions when those were requested.
        """
        all_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None

        for layer in self.layers:
            # Record the input of each layer; the final output is appended after the loop.
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            layer_outputs = layer(
                hidden_states, attention_mask, deterministic=deterministic, output_attentions=output_attentions
            )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions += (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if not return_dict:
            # Entries that were not requested are None and are dropped, so the tuple stays
            # `(hidden_states,)` when neither flag is set.
            outputs = (hidden_states, all_hidden_states, all_attentions)
            return tuple(v for v in outputs if v is not None)

        return FlaxBaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
        )


# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertEncoder with Bert->Electra
class FlaxElectraEncoder(nn.Module):
    """Thin wrapper that exposes the layer collection under the `layer` attribute name."""

    config: ElectraConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.layer = FlaxElectraLayerCollection(self.config, dtype=self.dtype)

    def __call__(
        self,
        hidden_states,
        attention_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        return self.layer(
            hidden_states,
            attention_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )


class FlaxElectraGeneratorPredictions(nn.Module):
    """Prediction module for the generator: dense layer, activation, then layer norm."""

    config: ElectraConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
        self.dense = nn.Dense(self.config.embedding_size, dtype=self.dtype)

    def __call__(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = ACT2FN[self.config.hidden_act](hidden_states)
        hidden_states = self.LayerNorm(hidden_states)
        return hidden_states


class FlaxElectraDiscriminatorPredictions(nn.Module):
    """Prediction module for the discriminator, made up of two dense layers."""

    config: ElectraConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.dense = nn.Dense(self.config.hidden_size, dtype=self.dtype)
        self.dense_prediction = nn.Dense(1, dtype=self.dtype)

    def __call__(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = ACT2FN[self.config.hidden_act](hidden_states)
        # The final projection has a single unit; drop that trailing axis.
        hidden_states = self.dense_prediction(hidden_states).squeeze(-1)
        return hidden_states


class FlaxElectraPreTrainedModel(FlaxPreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = ElectraConfig
    base_model_prefix = "electra"
    module_class: nn.Module = None

    def __init__(
        self,
        config: ElectraConfig,
        input_shape: Tuple = (1, 1),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        **kwargs
    ):
        module = self.module_class(config=config, dtype=dtype, **kwargs)
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)

    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
        """Initialize the module parameters from dummy inputs of shape `input_shape`."""
        # init input tensors
        input_ids = jnp.zeros(input_shape, dtype="i4")
        token_type_ids = jnp.zeros_like(input_ids)
        position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape)
        attention_mask = jnp.ones_like(input_ids)

        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        return self.module.init(rngs, input_ids, attention_mask, token_type_ids, position_ids, return_dict=False)[
            "params"
        ]

    @add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    def __call__(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        params: dict = None,
        dropout_rng: PRNGKey = None,
        train: bool = False,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        # Unset flags fall back to the values stored on the config.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        # init input tensors if not passed
        if token_type_ids is None:
            # Segment id 0 ("sentence A") is the default, matching `init_weights`.
            token_type_ids = jnp.zeros_like(input_ids)

        if position_ids is None:
            position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)

        if attention_mask is None:
            attention_mask = jnp.ones_like(input_ids)

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        # Positional order after the four inputs: deterministic, output_attentions, output_hidden_states,
        # return_dict -- `deterministic` is the negation of `train`.
        return self.module.apply(
            {"params": params or self.params},
            jnp.array(input_ids, dtype="i4"),
            jnp.array(attention_mask, dtype="i4"),
            jnp.array(token_type_ids, dtype="i4"),
            jnp.array(position_ids, dtype="i4"),
            not train,
            output_attentions,
            output_hidden_states,
            return_dict,
            rngs=rngs,
        )


class FlaxElectraModule(nn.Module):
    """Embeddings, optional projection to `hidden_size`, and the transformer encoder."""

    config: ElectraConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.embeddings = FlaxElectraEmbeddings(self.config, dtype=self.dtype)
        # The projection only exists when the embedding width differs from the encoder width.
        if self.config.embedding_size != self.config.hidden_size:
            self.embeddings_project = nn.Dense(self.config.hidden_size, dtype=self.dtype)
        self.encoder = FlaxElectraEncoder(self.config, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        position_ids,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        embeddings = self.embeddings(
            input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic
        )
        if hasattr(self, "embeddings_project"):
            embeddings = self.embeddings_project(embeddings)

        return self.encoder(
            embeddings,
            attention_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )


@add_start_docstrings(
    "The bare Electra Model transformer outputting raw hidden-states without any specific head on top.",
    ELECTRA_START_DOCSTRING,
)
class FlaxElectraModel(FlaxElectraPreTrainedModel):
    module_class = FlaxElectraModule


append_call_sample_docstring(
    FlaxElectraModel, _TOKENIZER_FOR_DOC, _CHECKPOINT_FOR_DOC, FlaxBaseModelOutput, _CONFIG_FOR_DOC
)


class FlaxElectraTiedDense(nn.Module):
    """Dense layer whose kernel is supplied by the caller at call time; only the bias is a parameter here."""

    embedding_size: int
    dtype: jnp.dtype = jnp.float32
    # Un-annotated on purpose: a plain class attribute, not a constructor field.
    precision = None
    bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros

    def setup(self):
        bias = self.param("bias", self.bias_init, (self.embedding_size,))
        self.bias = jnp.asarray(bias, dtype=self.dtype)

    def __call__(self, x, kernel):
        # Contract the last axis of `x` with the first axis of `kernel`.
        y = lax.dot_general(
            x,
            kernel,
            (((x.ndim - 1,), (0,)), ((), ())),
            precision=self.precision,
        )
        return y + self.bias


class FlaxElectraForMaskedLMModule(nn.Module):
    """Electra encoder with the generator prediction head and a (possibly tied) vocabulary projection."""

    config: ElectraConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.electra = FlaxElectraModule(config=self.config, dtype=self.dtype)
        self.generator_predictions = FlaxElectraGeneratorPredictions(config=self.config, dtype=self.dtype)
        if self.config.tie_word_embeddings:
            self.generator_lm_head = FlaxElectraTiedDense(self.config.vocab_size, dtype=self.dtype)
        else:
            self.generator_lm_head = nn.Dense(self.config.vocab_size, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        outputs = self.electra(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = outputs[0]
        prediction_scores = self.generator_predictions(hidden_states)

        if self.config.tie_word_embeddings:
            # Reuse the (transposed) word-embedding matrix as the output projection kernel.
            shared_embedding = self.electra.variables["params"]["embeddings"]["word_embeddings"]["embedding"]
            prediction_scores = self.generator_lm_head(prediction_scores, shared_embedding.T)
        else:
            prediction_scores = self.generator_lm_head(prediction_scores)

        if not return_dict:
            return (prediction_scores,) + outputs[1:]

        return FlaxMaskedLMOutput(
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@add_start_docstrings("""Electra Model with a `language modeling` head on top. """, ELECTRA_START_DOCSTRING)
class FlaxElectraForMaskedLM(FlaxElectraPreTrainedModel):
    module_class = FlaxElectraForMaskedLMModule


append_call_sample_docstring(
    FlaxElectraForMaskedLM, _TOKENIZER_FOR_DOC, _CHECKPOINT_FOR_DOC, FlaxMaskedLMOutput, _CONFIG_FOR_DOC
)


class FlaxElectraForPreTrainingModule(nn.Module):
    """Electra encoder with the discriminator prediction head."""

    config: ElectraConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.electra = FlaxElectraModule(config=self.config, dtype=self.dtype)
        self.discriminator_predictions = FlaxElectraDiscriminatorPredictions(config=self.config, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # Model
        outputs = self.electra(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = outputs[0]

        logits = self.discriminator_predictions(hidden_states)

        if not return_dict:
            return (logits,) + outputs[1:]

        return FlaxElectraForPreTrainingOutput(
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@add_start_docstrings(
    """
    Electra model with a binary classification head on top as used during pretraining for identifying generated tokens.

    It is recommended to load the discriminator checkpoint into that model.
    """,
    ELECTRA_START_DOCSTRING,
)
class FlaxElectraForPreTraining(FlaxElectraPreTrainedModel):
    module_class = FlaxElectraForPreTrainingModule


FLAX_ELECTRA_FOR_PRETRAINING_DOCSTRING = """
    Returns:

    Example::

        >>> from transformers import ElectraTokenizer, FlaxElectraForPreTraining

        >>> tokenizer = ElectraTokenizer.from_pretrained('google/electra-small-discriminator')
        >>> model = FlaxElectraForPreTraining.from_pretrained('google/electra-small-discriminator')

        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="jax")
        >>> outputs = model(**inputs)

        >>> prediction_logits = outputs.logits
"""

overwrite_call_docstring(
    FlaxElectraForPreTraining,
    ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length") + FLAX_ELECTRA_FOR_PRETRAINING_DOCSTRING,
)
append_replace_return_docstrings(
    FlaxElectraForPreTraining, output_type=FlaxElectraForPreTrainingOutput, config_class=_CONFIG_FOR_DOC
)


class FlaxElectraForTokenClassificationModule(nn.Module):
    """Electra encoder with dropout and a per-token linear classifier."""

    config: ElectraConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.electra = FlaxElectraModule(config=self.config, dtype=self.dtype)
        self.dropout = nn.Dropout(self.config.hidden_dropout_prob)
        self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # Model
        outputs = self.electra(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = outputs[0]

        hidden_states = self.dropout(hidden_states, deterministic=deterministic)
        logits = self.classifier(hidden_states)

        if not return_dict:
            return (logits,) + outputs[1:]

        return FlaxTokenClassifierOutput(
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@add_start_docstrings(
    """
    Electra model with a token classification head on top.

    Both the discriminator and generator may be loaded into this model.
    """,
    ELECTRA_START_DOCSTRING,
)
class FlaxElectraForTokenClassification(FlaxElectraPreTrainedModel):
    module_class = FlaxElectraForTokenClassificationModule


append_call_sample_docstring(
    FlaxElectraForTokenClassification,
    _TOKENIZER_FOR_DOC,
    _CHECKPOINT_FOR_DOC,
    FlaxTokenClassifierOutput,
    _CONFIG_FOR_DOC,
)


def identity(x, **kwargs):
    """Return `x` unchanged; extra keyword arguments (e.g. `deterministic`) are accepted and ignored."""
    return x


class FlaxElectraSequenceSummary(nn.Module):
    r"""
    Compute a single vector summary of a sequence hidden states.

    Args:
        config (:class:`~transformers.PretrainedConfig`):
            The config used by the model. Relevant arguments in the config class of the model are (refer to the actual
            config class of your model for the default values it uses):

            - **summary_use_proj** (:obj:`bool`) -- Add a projection after the vector extraction.
            - **summary_proj_to_labels** (:obj:`bool`) -- If :obj:`True`, the projection outputs to
              :obj:`config.num_labels` classes (otherwise to :obj:`config.hidden_size`).
            - **summary_activation** (:obj:`Optional[str]`) -- Set to :obj:`"tanh"` to add a tanh activation to the
              output, another string or :obj:`None` will add no activation.
            - **summary_first_dropout** (:obj:`float`) -- Optional dropout probability before the projection and
              activation.
            - **summary_last_dropout** (:obj:`float`)-- Optional dropout probability after the projection and
              activation.
    """
    config: ElectraConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # Each stage defaults to a no-op and is replaced only when the config enables it.
        self.summary = identity
        if hasattr(self.config, "summary_use_proj") and self.config.summary_use_proj:
            if (
                hasattr(self.config, "summary_proj_to_labels")
                and self.config.summary_proj_to_labels
                and self.config.num_labels > 0
            ):
                num_classes = self.config.num_labels
            else:
                num_classes = self.config.hidden_size
            self.summary = nn.Dense(num_classes, dtype=self.dtype)

        activation_string = getattr(self.config, "summary_activation", None)
        self.activation = ACT2FN[activation_string] if activation_string else lambda x: x

        self.first_dropout = identity
        if hasattr(self.config, "summary_first_dropout") and self.config.summary_first_dropout > 0:
            self.first_dropout = nn.Dropout(self.config.summary_first_dropout)

        self.last_dropout = identity
        if hasattr(self.config, "summary_last_dropout") and self.config.summary_last_dropout > 0:
            self.last_dropout = nn.Dropout(self.config.summary_last_dropout)

    def __call__(self, hidden_states, cls_index=None, deterministic: bool = True):
        """
        Compute a single vector summary of a sequence hidden states.

        Args:
            hidden_states (:obj:`jnp.array` of shape :obj:`[batch_size, seq_len, hidden_size]`):
                The hidden states of the last layer.
            cls_index (:obj:`jnp.array` of shape :obj:`[batch_size]` or :obj:`[batch_size, ...]` where ... are optional leading dimensions of :obj:`hidden_states`, `optional`):
                Accepted for interface compatibility; this implementation does not use it.

        Returns:
            :obj:`jnp.array`: The summary of the sequence hidden states.
        """
        # NOTE: this always does a "first" type summary (the hidden state at position 0)
        output = hidden_states[:, 0]
        output = self.first_dropout(output, deterministic=deterministic)
        output = self.summary(output)
        output = self.activation(output)
        output = self.last_dropout(output, deterministic=deterministic)
        return output


class FlaxElectraForMultipleChoiceModule(nn.Module):
    """Electra encoder with a sequence summary and a single-unit classifier scored per choice."""

    config: ElectraConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.electra = FlaxElectraModule(config=self.config, dtype=self.dtype)
        self.sequence_summary = FlaxElectraSequenceSummary(config=self.config, dtype=self.dtype)
        self.classifier = nn.Dense(1, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        num_choices = input_ids.shape[1]
        # Fold the choices axis into the batch axis so the encoder sees 2-D inputs.
        input_ids = input_ids.reshape(-1, input_ids.shape[-1]) if input_ids is not None else None
        attention_mask = attention_mask.reshape(-1, attention_mask.shape[-1]) if attention_mask is not None else None
        token_type_ids = token_type_ids.reshape(-1, token_type_ids.shape[-1]) if token_type_ids is not None else None
        position_ids = position_ids.reshape(-1, position_ids.shape[-1]) if position_ids is not None else None

        # Model
        outputs = self.electra(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = outputs[0]
        pooled_output = self.sequence_summary(hidden_states, deterministic=deterministic)
        logits = self.classifier(pooled_output)

        # One score per (example, choice).
        reshaped_logits = logits.reshape(-1, num_choices)

        if not return_dict:
            return (reshaped_logits,) + outputs[1:]

        return FlaxMultipleChoiceModelOutput(
            logits=reshaped_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@add_start_docstrings(
    """
    ELECTRA Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
    softmax) e.g. for RocStories/SWAG tasks.
    """,
    ELECTRA_START_DOCSTRING,
)
class FlaxElectraForMultipleChoice(FlaxElectraPreTrainedModel):
    module_class = FlaxElectraForMultipleChoiceModule


# adapt docstring slightly for FlaxElectraForMultipleChoice
overwrite_call_docstring(
    FlaxElectraForMultipleChoice, ELECTRA_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")
)
append_call_sample_docstring(
    FlaxElectraForMultipleChoice,
    _TOKENIZER_FOR_DOC,
    _CHECKPOINT_FOR_DOC,
    FlaxMultipleChoiceModelOutput,
    _CONFIG_FOR_DOC,
)


class FlaxElectraForQuestionAnsweringModule(nn.Module):
    """Electra encoder with a linear head producing span start/end logits."""

    config: ElectraConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.electra = FlaxElectraModule(config=self.config, dtype=self.dtype)
        self.qa_outputs = nn.Dense(self.config.num_labels, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # Model
        outputs = self.electra(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = outputs[0]
        logits = self.qa_outputs(hidden_states)
        # The last axis has `num_labels` entries and is split into `num_labels` pieces; unpacking into
        # exactly two names therefore only works when `config.num_labels == 2`.
        start_logits, end_logits = logits.split(self.config.num_labels, axis=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)

        if not return_dict:
            return (start_logits, end_logits) + outputs[1:]

        return FlaxQuestionAnsweringModelOutput(
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@add_start_docstrings(
    """
    ELECTRA Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
    layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
    """,
    ELECTRA_START_DOCSTRING,
)
class FlaxElectraForQuestionAnswering(FlaxElectraPreTrainedModel):
    module_class = FlaxElectraForQuestionAnsweringModule


append_call_sample_docstring(
    FlaxElectraForQuestionAnswering,
    _TOKENIZER_FOR_DOC,
    _CHECKPOINT_FOR_DOC,
    FlaxQuestionAnsweringModelOutput,
    _CONFIG_FOR_DOC,
)


class FlaxElectraClassificationHead(nn.Module):
    """Head for sentence-level classification tasks."""

    config: ElectraConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.dense = nn.Dense(self.config.hidden_size, dtype=self.dtype)
        self.dropout = nn.Dropout(self.config.hidden_dropout_prob)
        self.out_proj = nn.Dense(self.config.num_labels, dtype=self.dtype)

    def __call__(self, hidden_states, deterministic: bool = True):
        x = hidden_states[:, 0, :]  # take <s> token (equiv. to [CLS])
        x = self.dropout(x, deterministic=deterministic)
        x = self.dense(x)
        x = ACT2FN["gelu"](x)  # although BERT uses tanh here, it seems Electra authors used gelu
        x = self.dropout(x, deterministic=deterministic)
        x = self.out_proj(x)
        return x


class FlaxElectraForSequenceClassificationModule(nn.Module):
    """Electra encoder with the sentence-level classification head."""

    config: ElectraConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.electra = FlaxElectraModule(config=self.config, dtype=self.dtype)
        self.classifier = FlaxElectraClassificationHead(config=self.config, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # Model
        outputs = self.electra(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = outputs[0]
        logits = self.classifier(hidden_states, deterministic=deterministic)

        if not return_dict:
            return (logits,) + outputs[1:]

        return FlaxSequenceClassifierOutput(
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@add_start_docstrings(
    """
    Electra Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
    output) e.g. for GLUE tasks.
    """,
    ELECTRA_START_DOCSTRING,
)
class FlaxElectraForSequenceClassification(FlaxElectraPreTrainedModel):
    module_class = FlaxElectraForSequenceClassificationModule


append_call_sample_docstring(
    FlaxElectraForSequenceClassification,
    _TOKENIZER_FOR_DOC,
    _CHECKPOINT_FOR_DOC,
    FlaxSequenceClassifierOutput,
    _CONFIG_FOR_DOC,
)