diff --git "a/modeling_hf_nomic_bert.py" "b/modeling_hf_nomic_bert.py" --- "a/modeling_hf_nomic_bert.py" +++ "b/modeling_hf_nomic_bert.py" @@ -1,2351 +1,2354 @@ -# Copyright (c) 2022, Tri Dao. -# This BERT implementation is based on our MLPerf 2.0 and MLPerf 2.1 BERT implementation. -# https://github.com/mlcommons/training_results_v2.0/blob/main/HazyResearch/benchmarks/bert/implementations/pytorch/modeling.py -# https://github.com/mlcommons/training_results_v2.1/blob/main/Azure-HazyResearch/benchmarks/bert/implementations/ND96amsr_A100_v4/modeling.py - -import collections -import logging - -# Inspired by https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py -import math -import os -import re -from collections import OrderedDict -from functools import partial -from typing import List, Optional, Tuple, Union - -import numpy as np -import torch -import torch.nn as nn -import torch.nn.functional as F -from einops import rearrange, repeat -from safetensors.torch import load_file as safe_load_file -from torch.nn.modules.utils import _pair -from transformers import GPT2Config, PreTrainedModel, ViTConfig, ViTModel -from transformers.models.bert.modeling_bert import ( - BaseModelOutputWithPoolingAndCrossAttentions, - MaskedLMOutput, - SequenceClassifierOutput, -) -from transformers.modeling_outputs import ( - BaseModelOutput, - BaseModelOutputWithPast, - BaseModelOutputWithPooling, - MaskedLMOutput, - MultipleChoiceModelOutput, - QuestionAnsweringModelOutput, - SequenceClassifierOutput, - ModelOutput, - TokenClassifierOutput, -) -from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, WEIGHTS_NAME -from transformers.utils.hub import cached_file, get_checkpoint_shard_files - -from .configuration_hf_nomic_bert import NomicBertConfig - -try: - from torch.nn.functional import scaled_dot_product_attention -except ImportError: - scaled_dot_product_attention = None - -logger = logging.getLogger(__name__) - - -# adapted from flash attention, added safe serialization option for hf models -def state_dict_from_pretrained(model_name, safe_serialization=False, device=None, dtype=None): - # If not fp32, then we don't want to load directly to the GPU - mapped_device = "cpu" if dtype not in [torch.float32, None] else device - is_sharded = False - load_safe = False - resolved_archive_file = None - - weights_path = os.path.join(model_name, WEIGHTS_NAME) - weights_index_path = os.path.join(model_name, WEIGHTS_INDEX_NAME) - safe_weights_path = os.path.join(model_name, SAFE_WEIGHTS_NAME) - safe_weights_index_path = os.path.join(model_name, SAFE_WEIGHTS_INDEX_NAME) - - if os.path.isfile(weights_path): - resolved_archive_file = cached_file(model_name, WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False) - elif os.path.isfile(weights_index_path): - resolved_archive_file = cached_file(model_name, WEIGHTS_INDEX_NAME, _raise_exceptions_for_missing_entries=False) - is_sharded = True - elif os.path.isfile(safe_weights_path): - resolved_archive_file = cached_file(model_name, SAFE_WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False) - load_safe = True - elif os.path.isfile(safe_weights_index_path): - resolved_archive_file = cached_file( - model_name, SAFE_WEIGHTS_INDEX_NAME, _raise_exceptions_for_missing_entries=False - ) - is_sharded = True - load_safe = True - else: # Try loading from HF hub instead of from local files - resolved_archive_file = None - for weight_name in [WEIGHTS_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_INDEX_NAME]: - resolved_archive_file = cached_file(model_name, weight_name, _raise_exceptions_for_missing_entries=False) - if resolved_archive_file is not None: - if weight_name in [SAFE_WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME]: - load_safe = True - if weight_name in [WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_INDEX_NAME]: - is_sharded = True - break - - if resolved_archive_file is None: - raise EnvironmentError(f"Model name {model_name} was not found.") - - if load_safe: - loader = partial(safe_load_file, device=mapped_device) - else: - loader = partial(torch.load, map_location=mapped_device) - - if is_sharded: - # resolved_archive_file becomes a list of files that point to the different - # checkpoint shards in this case. - resolved_archive_file, sharded_metadata = get_checkpoint_shard_files(model_name, resolved_archive_file) - state_dict = {} - for sharded_file in resolved_archive_file: - state_dict.update(loader(sharded_file)) - else: - state_dict = loader(resolved_archive_file) - # Convert dtype before moving to GPU to save memory - if dtype is not None: - state_dict = {k: v.to(dtype=dtype) for k, v in state_dict.items()} - state_dict = {k: v.to(device=device) for k, v in state_dict.items()} - return state_dict - - -def filter_shapes(state_dict, model): - """ - Filters the state dict to match the current model shape. - """ - filtered_state_dict = {} - for key, value in state_dict.items(): - if key in model.state_dict(): - if value.shape == model.state_dict()[key].shape: - filtered_state_dict[key] = value - return filtered_state_dict - - -def remap_bert_state_dict( - state_dict, - config, - remove_bert=False, - remove_cls_weights=False, - add_pooling_layer=False, -): - """ - Map the state_dict of a Huggingface BERT model to be flash_attn compatible. - """ - - def add_bert_prefix(key): - # prepend bert. to the key - if key.startswith("bert.") or key.startswith("cls."): - return key - return f"bert.{key}" - - state_dict = OrderedDict((add_bert_prefix(k), v) for k, v in state_dict.items()) - - # LayerNorm - def key_mapping_ln_gamma_beta(key): - key = re.sub(r"LayerNorm.gamma$", "LayerNorm.weight", key) - key = re.sub(r"LayerNorm.beta$", "LayerNorm.bias", key) - return key - - state_dict = OrderedDict((key_mapping_ln_gamma_beta(k), v) for k, v in state_dict.items()) - - # Layers - def key_mapping_layers(key): - return re.sub(r"^bert.encoder.layer\.", "bert.encoder.layers.", key) - - state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items()) - - # LayerNorm - def key_mapping_ln(key): - key = re.sub(r"^bert.embeddings.LayerNorm.", "bert.emb_ln.", key) - key = re.sub( - r"^bert.encoder.layers.(\d+).attention.output.LayerNorm.(weight|bias)", - r"bert.encoder.layers.\1.norm1.\2", - key, - ) - key = re.sub( - r"^bert.encoder.layers.(\d+).output.LayerNorm.(weight|bias)", - r"bert.encoder.layers.\1.norm2.\2", - key, - ) - key = re.sub( - r"^cls.predictions.transform.LayerNorm.(weight|bias)", - r"cls.predictions.transform.layer_norm.\1", - key, - ) - return key - - state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items()) - - # MLP - def key_mapping_mlp(key): - key = re.sub( - r"^bert.encoder.layers.(\d+).intermediate.dense.(weight|bias)", - r"bert.encoder.layers.\1.mlp.fc1.\2", - key, - ) - key = re.sub( - r"^bert.encoder.layers.(\d+).output.dense.(weight|bias)", - r"bert.encoder.layers.\1.mlp.fc2.\2", - key, - ) - return key - - state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items()) - - # Attention - last_layer_subset = getattr(config, "last_layer_subset", False) - for d in range(config.num_hidden_layers): - if f"bert.encoder.layers.{d}.attention.self.query.weight" not in state_dict: - continue - Wq = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.query.weight") - Wk = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.key.weight") - Wv = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.value.weight") - bq = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.query.bias") - bk = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.key.bias") - bv = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.value.bias") - if not (last_layer_subset and d == config.num_hidden_layers - 1): - state_dict[f"bert.encoder.layers.{d}.attn.Wqkv.weight"] = torch.cat([Wq, Wk, Wv], dim=0) - state_dict[f"bert.encoder.layers.{d}.attn.Wqkv.bias"] = torch.cat([bq, bk, bv], dim=0) - else: - state_dict[f"bert.encoder.layers.{d}.attn.Wq.weight"] = Wq - state_dict[f"bert.encoder.layers.{d}.attn.Wkv.weight"] = torch.cat([Wk, Wv], dim=0) - state_dict[f"bert.encoder.layers.{d}.attn.Wq.bias"] = bq - state_dict[f"bert.encoder.layers.{d}.attn.Wkv.bias"] = torch.cat([bk, bv], dim=0) - - def key_mapping_attn(key): - return re.sub( - r"^bert.encoder.layers.(\d+).attention.output.dense.(weight|bias)", - r"bert.encoder.layers.\1.attn.out_proj.\2", - key, - ) - - state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items()) - - def key_mapping_decoder_bias(key): - return re.sub(r"^cls.predictions.bias", "cls.predictions.decoder.bias", key) - - # remove nsp weights, we don't use - state_dict.pop("cls.seq_relationship.weight", None) - state_dict.pop("cls.seq_relationship.bias", None) - state_dict.pop("bert.embeddings.position_ids", None) - - state_dict = OrderedDict((key_mapping_decoder_bias(k), v) for k, v in state_dict.items()) - - if remove_cls_weights: - cls_weights = [ - "cls.predictions.decoder.bias", - "cls.predictions.transform.dense.weight", - "cls.predictions.transform.dense.bias", - "cls.predictions.transform.layer_norm.weight", - "cls.predictions.transform.layer_norm.bias", - "cls.predictions.decoder.weight", - ] - for weight in cls_weights: - state_dict.pop(weight, None) - - # Word embedding - pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1) - if pad_vocab_size_multiple > 1: - word_embeddings = state_dict["bert.embeddings.word_embeddings.weight"] - state_dict["bert.embeddings.word_embeddings.weight"] = F.pad( - word_embeddings, (0, 0, 0, config.vocab_size - word_embeddings.shape[0]) - ) - if not remove_cls_weights: - decoder_weight = state_dict["cls.predictions.decoder.weight"] - state_dict["cls.predictions.decoder.weight"] = F.pad( - decoder_weight, (0, 0, 0, config.vocab_size - decoder_weight.shape[0]) - ) - # If the vocab was padded, we want to set the decoder bias for those padded indices to be - # strongly negative (i.e. the decoder shouldn't predict those indices). - # TD [2022-05-09]: I don't think it affects the MLPerf training. - if "cls.predictions.decoder.bias" in state_dict: - decoder_bias = state_dict["cls.predictions.decoder.bias"] - state_dict["cls.predictions.decoder.bias"] = F.pad( - decoder_bias, (0, config.vocab_size - decoder_bias.shape[0]), value=-100.0 - ) - - if add_pooling_layer is False: - pooler_weights = [ - "bert.pooler.dense.weight", - "bert.pooler.dense.bias", - ] - for key in pooler_weights: - state_dict.pop(key, None) - - if remove_bert: - - def remove_bert_prefix(key): - key = re.sub(r"^bert.", "", key) - return key - - state_dict = OrderedDict((remove_bert_prefix(k), v) for k, v in state_dict.items()) - - return state_dict - - -def _trunc_normal_(tensor, mean, std, a, b): - # Cut & paste from PyTorch official master until it's in a few official releases - RW - # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf - def norm_cdf(x): - # Computes standard normal cumulative distribution function - return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 - - if (mean < a - 2 * std) or (mean > b + 2 * std): - print( - "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " - "The distribution of values may be incorrect.", - stacklevel=2, - ) - - # Values are generated by using a truncated uniform distribution and - # then using the inverse CDF for the normal distribution. - # Get upper and lower cdf values - l = norm_cdf((a - mean) / std) - u = norm_cdf((b - mean) / std) - - # Uniformly fill tensor with values from [l, u], then translate to - # [2l-1, 2u-1]. - tensor.uniform_(2 * l - 1, 2 * u - 1) - - # Use inverse cdf transform for normal distribution to get truncated - # standard normal - tensor.erfinv_() - - # Transform to proper mean, std - tensor.mul_(std * math.sqrt(2.0)) - tensor.add_(mean) - - # Clamp to ensure it's in the proper range - tensor.clamp_(min=a, max=b) - return tensor - - -def trunc_normal_tf_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0): - r"""Fills the input Tensor with values drawn from a truncated - normal distribution. The values are effectively drawn from the - normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` - with values outside :math:`[a, b]` redrawn until they are within - the bounds. The method used for generating the random values works - best when :math:`a \leq \text{mean} \leq b`. - - NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the - bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0 - and the result is subsquently scaled and shifted by the mean and std args. - - Args: - tensor: an n-dimensional `torch.Tensor` - mean: the mean of the normal distribution - std: the standard deviation of the normal distribution - a: the minimum cutoff value - b: the maximum cutoff value - Examples: - >>> w = torch.empty(3, 5) - >>> nn.init.trunc_normal_(w) - """ - with torch.no_grad(): - _trunc_normal_(tensor, 0, 1.0, a, b) - tensor.mul_(std).add_(mean) - return tensor - - -class NomicBertPreTrainedModel(PreTrainedModel): - """An abstract class to handle weights initialization and - a simple interface for dowloading and loading pretrained models. - """ - - config_class = NomicBertConfig - base_model_prefix = "model" - supports_gradient_checkpointing = True - _no_split_modules = ["Block"] - _skip_keys_device_placement = "past_key_values" - - def __init__(self, config, *inputs, **kwargs): - super().__init__(config) - if not isinstance(config, GPT2Config): - raise ValueError( - "Parameter config in `{}(config)` should be an instance of class `GPT2Config`. " - "To create a model from a Google pretrained model use " - "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format( - self.__class__.__name__, self.__class__.__name__ - ) - ) - self.config = config - - @classmethod - def from_pretrained(cls, model_name, config=None, *inputs, **kwargs): - """ - Instantiate a NomicBertPreTrainedModel from a pre-trained model file or a pytorch state dict. - Download and cache the pre-trained model file if needed. - - Params: - pretrained_model_name_or_path: either: - - a path or url to a pretrained model archive containing: - . `bert_config.json` a configuration file for the model - . `pytorch_model.bin` a PyTorch dump of a NomicBertForPretraining instance - - a path or url to a pretrained model archive containing: - . `bert_config.json` a configuration file for the model - . `model.chkpt` a TensorFlow checkpoint - *inputs, **kwargs: additional input for the specific NomicBert class - (ex: num_labels for NomicBertForSequenceClassification) - """ - # Instantiate model. - if config is None: - config = cls.config_class.from_pretrained(model_name) - remove_cls = cls != NomicBertForPreTraining - remove_bert_prefix = cls not in [NomicBertForPreTraining, NomicBertForSequenceClassification, NomicBertForTokenClassification, NomicBertForMultipleChoice, NomicBertForQuestionAnswering] - ignore_mismatched_shapes = kwargs.pop("ignore_mismatched_sizes", False) - num_labels = kwargs.pop("num_labels", None) - rotary_scaling_factor = kwargs.pop("rotary_scaling_factor", None) - strict = kwargs.pop("strict", True) - dtype = kwargs.pop("torch_dtype", None) - if rotary_scaling_factor: - config.rotary_scaling_factor = rotary_scaling_factor - - if config.n_positions <= 0 and config.rotary_emb_fraction > 0: - config.n_positions = 2048 - if num_labels: - config.num_labels = num_labels - - if "add_pooling_layer" in kwargs: - model = cls(config, *inputs, add_pooling_layer=kwargs.pop("add_pooling_layer")) - else: - if cls == NomicBertModel: - model = cls(config, *inputs, add_pooling_layer=False) - else: - model = cls(config, *inputs) - - if dtype is not None: - model = model.to(dtype=dtype) - # TODO: fix this - # Assuming we know what we're doing when loading from disk - # Prob a bad assumption but i'm tired and want to train this asap - if os.path.exists(model_name): - model_path = f"{model_name}/pytorch_model.bin" - if os.path.exists(model_path): - state_dict = torch.load(f"{model_name}/pytorch_model.bin") - else: - model_path = f"{model_name}/model.safetensors" - if not os.path.exists(model_path): - raise ValueError(f"Model path {model_path} not found") - state_dict = safe_load_file(model_path) - - if ignore_mismatched_shapes: - state_dict = filter_shapes(state_dict, model) - load_return = model.load_state_dict(state_dict, strict=False) - else: - # TODO: can probably check config class and see if we need to remap from a bert model - state_dict = state_dict_from_pretrained(model_name, dtype=dtype) - state_dict = remap_bert_state_dict( - state_dict, - config, - remove_bert=remove_bert_prefix, - remove_cls_weights=remove_cls, - add_pooling_layer=getattr(config, "add_pooling_layer", False), - ) - if ignore_mismatched_shapes: - state_dict = filter_shapes(state_dict, model) - - load_return = model.load_state_dict(state_dict, strict=strict) - logger.warning(load_return) - return model - - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, NomicBertEncoder): - module.gradient_checkpointing = value - - -# https://github.com/huggingface/transformers/blob/7032e0203262ebb2ebf55da8d2e01f873973e835/src/transformers/models/bert/modeling_bert.py#L748 -def _init_weights(module, initializer_range=0.02): - if isinstance(module, nn.Linear): - nn.init.normal_(module.weight, std=initializer_range) - if module.bias is not None: - nn.init.zeros_(module.bias) - elif isinstance(module, nn.Embedding): - nn.init.normal_(module.weight, std=initializer_range) - if module.padding_idx is not None: - nn.init.zeros_(module.weight[module.padding_idx]) - - -def _ntuple(n): - def parse(x): - if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): - return tuple(x) - return tuple(repeat(x, n)) - - return parse - - -to_1tuple = _ntuple(1) -to_2tuple = _ntuple(2) -to_3tuple = _ntuple(3) -to_4tuple = _ntuple(4) -to_ntuple = _ntuple - - -def get_2d_sincos_pos_embed(embed_dim, grid_size, add_cls_token=False): - """ - Create 2D sin/cos positional embeddings. - - Args: - embed_dim (`int`): - Embedding dimension. - grid_size (`int`): - The grid height and width. - add_cls_token (`bool`, *optional*, defaults to `False`): - Whether or not to add a classification (CLS) token. - - Returns: - (`torch.FloatTensor` of shape (grid_size*grid_size, embed_dim) or (1+grid_size*grid_size, embed_dim): the - position embeddings (with or without classification token) - """ - grid_h = np.arange(grid_size, dtype=np.float32) - - grid_w = np.arange(grid_size, dtype=np.float32) - grid = np.meshgrid(grid_w, grid_h) # here w goes first - grid = np.stack(grid, axis=0) - - grid = grid.reshape([2, 1, grid_size, grid_size]) - pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) - if add_cls_token: - pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) - return pos_embed - - -def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): - if embed_dim % 2 != 0: - raise ValueError("embed_dim must be even") - - # use half of dimensions to encode grid_h - emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) - emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) - - emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) - return emb - - -def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): - """ - embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D) - """ - if embed_dim % 2 != 0: - raise ValueError("embed_dim must be even") - - omega = np.arange(embed_dim // 2, dtype=float) - omega /= embed_dim / 2.0 - omega = 1.0 / 10000**omega # (D/2,) - - pos = pos.reshape(-1) # (M,) - out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product - - emb_sin = np.sin(out) # (M, D/2) - emb_cos = np.cos(out) # (M, D/2) - - emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) - return emb - - -def ndgrid(*tensors) -> Tuple[torch.Tensor, ...]: - """generate N-D grid in dimension order. - - The ndgrid function is like meshgrid except that the order of the first two input arguments are switched. - - That is, the statement - [X1,X2,X3] = ndgrid(x1,x2,x3) - - produces the same result as - - [X2,X1,X3] = meshgrid(x2,x1,x3) - - This naming is based on MATLAB, the purpose is to avoid confusion due to torch's change to make - torch.meshgrid behaviour move from matching ndgrid ('ij') indexing to numpy meshgrid defaults of ('xy'). - - """ - try: - return torch.meshgrid(*tensors, indexing='ij') - except TypeError: - # old PyTorch < 1.10 will follow this path as it does not have indexing arg, - # the old behaviour of meshgrid was 'ij' - return torch.meshgrid(*tensors) - - -def build_fourier_pos_embed( - feat_shape: List[int], - bands: Optional[torch.Tensor] = None, - num_bands: int = 64, - max_res: int = 224, - temperature: float = 10000.0, - linear_bands: bool = False, - include_grid: bool = False, - in_pixels: bool = True, - ref_feat_shape: Optional[List[int]] = None, - dtype: torch.dtype = torch.float32, - device: Optional[torch.device] = None, -) -> List[torch.Tensor]: - """ - - Args: - feat_shape: Feature shape for embedding. - bands: Pre-calculated frequency bands. - num_bands: Number of frequency bands (determines output dim). - max_res: Maximum resolution for pixel based freq. - temperature: Temperature for non-pixel freq. - linear_bands: Linear band spacing for pixel based freq. - include_grid: Include the spatial grid in output. - in_pixels: Output in pixel freq. - ref_feat_shape: Reference feature shape for resize / fine-tune. - dtype: Output dtype. - device: Output device. - - Returns: - - """ - if bands is None: - if in_pixels: - bands = pixel_freq_bands( - num_bands, - float(max_res), - linear_bands=linear_bands, - device=device, - ) - else: - bands = freq_bands( - num_bands, - temperature=temperature, - step=1, - device=device, - ) - else: - if device is None: - device = bands.device - if dtype is None: - dtype = bands.dtype - - if in_pixels: - t = [torch.linspace(-1.0, 1.0, steps=s, device=device, dtype=torch.float32) for s in feat_shape] - else: - t = [torch.arange(s, device=device, dtype=torch.int64).to(torch.float32) for s in feat_shape] - - if ref_feat_shape is not None: - # eva's scheme for resizing rope embeddings (ref shape = pretrain) - t = [x / f * r for x, f, r in zip(t, feat_shape, ref_feat_shape)] - - grid = torch.stack(ndgrid(t), dim=-1) - grid = grid.unsqueeze(-1) - pos = grid * bands - - pos_sin, pos_cos = pos.sin().to(dtype=dtype), pos.cos().to(dtype) - out = [grid, pos_sin, pos_cos] if include_grid else [pos_sin, pos_cos] - return out - - -def build_rotary_pos_embed( - feat_shape: List[int], - bands: Optional[torch.Tensor] = None, - dim: int = 64, - max_res: int = 224, - temperature: float = 10000.0, - linear_bands: bool = False, - in_pixels: bool = True, - ref_feat_shape: Optional[List[int]] = None, - dtype: torch.dtype = torch.float32, - device: Optional[torch.device] = None, -): - """ - - Args: - feat_shape: Spatial shape of the target tensor for embedding. - bands: Optional pre-generated frequency bands - dim: Output dimension of embedding tensor. - max_res: Maximum resolution for pixel mode. - temperature: Temperature (inv freq) for non-pixel mode - linear_bands: Linearly (instead of log) spaced bands for pixel mode - in_pixels: Pixel vs language (inv freq) mode. - dtype: Output dtype. - device: Output device. - - Returns: - - """ - sin_emb, cos_emb = build_fourier_pos_embed( - feat_shape, - bands=bands, - num_bands=dim // 4, - max_res=max_res, - temperature=temperature, - linear_bands=linear_bands, - in_pixels=in_pixels, - ref_feat_shape=ref_feat_shape, - device=device, - dtype=dtype, - ) - num_spatial_dim = 1 - # this would be much nicer as a .numel() call to torch.Size(), but torchscript sucks - for x in feat_shape: - num_spatial_dim *= x - sin_emb = sin_emb.reshape(num_spatial_dim, -1).repeat_interleave(2, -1) - cos_emb = cos_emb.reshape(num_spatial_dim, -1).repeat_interleave(2, -1) - return sin_emb, cos_emb - - -def freq_bands( - num_bands: int, - temperature: float = 10000.0, - step: int = 2, - device: Optional[torch.device] = None, -) -> torch.Tensor: - exp = torch.arange(0, num_bands, step, dtype=torch.int64, device=device).to(torch.float32) / num_bands - bands = 1.0 / (temperature**exp) - return bands - - -def pixel_freq_bands( - num_bands: int, - max_freq: float = 224.0, - linear_bands: bool = True, - device: Optional[torch.device] = None, -): - if linear_bands: - bands = torch.linspace(1.0, max_freq / 2, num_bands, dtype=torch.float32, device=device) - else: - bands = 2 ** torch.linspace(0, math.log(max_freq, 2) - 1, num_bands, dtype=torch.float32, device=device) - return bands * torch.pi - - -def rot(x): - return torch.stack([-x[..., 1::2], x[..., ::2]], -1).reshape(x.shape) - - -def apply_rot_embed_cat(x: torch.Tensor, emb): - sin_emb, cos_emb = emb.tensor_split(2, -1) - if sin_emb.ndim == 3: - return x * cos_emb.unsqueeze(1).expand_as(x) + rot(x) * sin_emb.unsqueeze(1).expand_as(x) - return x * cos_emb + rot(x) * sin_emb - - -# taken from https://github.com/huggingface/pytorch-image-models/blob/cb0e4391beedcc5ac3ae4bce16561b95c326f32c/timm/layers/pos_embed_sincos.py#L363 -class NomicVisionRotaryEmbeddingCat(nn.Module): - """Rotary position embedding w/ concatenatd sin & cos - - The following impl/resources were referenced for this impl: - * https://github.com/lucidrains/vit-pytorch/blob/6f3a5fcf0bca1c5ec33a35ef48d97213709df4ba/vit_pytorch/rvt.py - * https://blog.eleuther.ai/rotary-embeddings/ - """ - - def __init__( - self, - dim, - max_res=224, - temperature=10000, - in_pixels=True, - linear_bands: bool = False, - feat_shape: Optional[List[int]] = None, - ref_feat_shape: Optional[List[int]] = None, - ): - super().__init__() - self.dim = dim - self.max_res = max_res - self.temperature = temperature - self.in_pixels = in_pixels - self.feat_shape = feat_shape - self.ref_feat_shape = ref_feat_shape - - if feat_shape is None: - # only cache bands - if in_pixels: - bands = pixel_freq_bands( - dim // 4, - float(max_res), - linear_bands=linear_bands, - ) - else: - bands = freq_bands( - dim // 4, - temperature=temperature, - step=1, - ) - self.register_buffer( - 'bands', - bands, - persistent=False, - ) - self.pos_embed = None - else: - # cache full sin/cos embeddings if shape provided up front - embeds = build_rotary_pos_embed( - feat_shape=feat_shape, - dim=dim, - max_res=max_res, - linear_bands=linear_bands, - in_pixels=in_pixels, - ref_feat_shape=self.ref_feat_shape, - ) - self.bands = None - self.register_buffer( - 'pos_embed', - torch.cat(embeds, -1), - persistent=False, - ) - - def get_embed(self, shape: Optional[List[int]] = None): - if self.bands is not None and shape is not None: - # rebuild embeddings every call, use if target shape changes - embeds = build_rotary_pos_embed( - shape, - self.bands, - in_pixels=self.in_pixels, - ref_feat_shape=self.ref_feat_shape, - ) - return torch.cat(embeds, -1) - elif self.pos_embed is not None: - return self.pos_embed - else: - assert False, "get_embed() requires pre-computed pos_embed or valid shape w/ pre-computed bands" - - def forward(self, x): - # assuming channel-first tensor where spatial dim are >= 2 - pos_embed = self.get_embed(x.shape[2:]) - return apply_rot_embed_cat(x, pos_embed) - - -class NomicVisionPatchEmbeddings(nn.Module): - def __init__( - self, - config, - ): - super().__init__() - img_size = _pair(config.img_size) - patch_size = _pair(config.patch_size) - self.img_size = img_size - self.patch_size = patch_size - self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) - self.num_patches = self.grid_size[0] * self.grid_size[1] - - self.proj = nn.Linear( - config.num_channels * patch_size[0] * patch_size[1], config.n_embd, bias=config.patch_embed_bias - ) - - self.learned_pos_embedding = False - self.sinusoidal_pos_embedding = False - self.no_embed_class = getattr(config, "no_embed_class", False) - - self.cls_token = ( - nn.Parameter(torch.zeros(1, 1, config.n_embd)) if not getattr(config, "no_cls_token", False) else None - ) - if config.learned_pos_embedding: - # this is the default in DINO - self.learned_pos_embedding = True - # hack for timm dinov2 with registers - num_patches = self.num_patches if getattr(config, "register_tokens", 0) > 0 else self.num_patches + 1 - self.pos_embed = ( - nn.Parameter(torch.randn(1, num_patches, config.n_embd) * 0.02) - if getattr(config, "use_pos_embed", True) - else None - ) - elif getattr(config, "sinusoidal_pos_embedding", False): - self.sinusoidal_pos_embedding = True - if getattr(config, "use_pos_embed", True): - self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, config.n_embd), requires_grad=False) - pos_embed = get_2d_sincos_pos_embed(config.n_embd, self.grid_size[0], add_cls_token=True) - self.pos_embed.data.copy_(torch.from_numpy(pos_embed).to(self.pos_embed)) - else: - self.pos_embed = None - else: - self.pos_embed = ( - nn.Parameter(torch.randn(1, self.num_patches + 1, config.n_embd) * 0.02) - if getattr(config, "use_pos_embed", True) - else None - ) - - if getattr(config, "register_tokens", 0) > 0: - self.reg_token = nn.Parameter(torch.randn(1, config.register_tokens, config.n_embd) * 0.02) - else: - self.reg_token = None - - if config.mask_token: - self.mask_token = nn.Parameter(torch.zeros(1, config.n_embd)) - - self.patch_dropout = nn.Identity() - - if getattr(config, "use_rotary_pos_emb", False): - ref_feat_shape = getattr(config, "ref_feat_shape", None) - ref_feat_shape = to_2tuple(ref_feat_shape) if ref_feat_shape is not None else None - self.rope = NomicVisionRotaryEmbeddingCat( - config.n_embd // config.n_head, - in_pixels=False, - feat_shape=self.grid_size, - ref_feat_shape=ref_feat_shape, - ) - else: - self.rope = None - - def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: - """ - This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher - resolution images. - - Source: - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174 - """ - num_patches = embeddings.shape[1] - 1 - num_positions = self.pos_embed.shape[1] - 1 - if num_patches == num_positions and height == width: - return self.pos_embed - class_pos_embed = self.pos_embed[:, 0] - patch_pos_embed = self.pos_embed[:, 1:] - dim = embeddings.shape[-1] - height = height // self.patch_size[0] - width = width // self.patch_size[1] - # we add a small number to avoid floating point error in the interpolation - # see discussion at https://github.com/facebookresearch/dino/issues/8 - height, width = height + 0.1, width + 0.1 - patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim) - patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) - patch_pos_embed = nn.functional.interpolate( - patch_pos_embed, - scale_factor=(height / math.sqrt(num_positions), width / math.sqrt(num_positions)), - mode="bicubic", - align_corners=False, - ) - if int(height) != patch_pos_embed.shape[-2] or int(width) != patch_pos_embed.shape[-1]: - raise ValueError("Width or height does not match with the interpolated position embeddings") - patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) - return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) - - def forward(self, x): - # deepspeed case where the input is in fp32 - if x.dtype != self.proj.weight.dtype: - x = x.to(dtype=self.proj.weight.dtype) - - _, _, height, width = x.shape - x = self.proj( - rearrange( - x, - "b c (h p1) (w p2) -> b h w (c p1 p2)", - p1=self.patch_size[0], - p2=self.patch_size[1], - ) - ) - embeddings = rearrange(x, "b h w c -> b (h w) c") - - to_cat = [] - if self.cls_token is not None: - if self.sinusoidal_pos_embedding: - cls_token = self.cls_token + self.pos_embed[:, 0] - cls_token = cls_token.expand(embeddings.shape[0], -1, -1) - to_cat += [cls_token] - else: - cls_token = self.cls_token.expand(embeddings.shape[0], 1, -1) - to_cat += [cls_token] - - if self.reg_token is not None: - to_cat += [self.reg_token.expand(embeddings.shape[0], -1, -1)] - - rot_pos_embed = self.rope.get_embed() if self.rope is not None else None - - if self.no_embed_class: - if self.learned_pos_embedding: - embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) - else: - if self.pos_embed is not None: - embeddings = embeddings + self.pos_embed - if to_cat: - embeddings = torch.cat(to_cat + [embeddings], dim=1) - else: - if to_cat: - embeddings = torch.cat(to_cat + [embeddings], dim=1) - if self.learned_pos_embedding: - if self.pos_embed is not None: - embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) - else: - if self.pos_embed is not None: - embeddings = embeddings + self.pos_embed - - embeddings = self.patch_dropout(embeddings) - - return embeddings, rot_pos_embed - - -class NomicBertEmbeddings(nn.Module): - def __init__(self, config): - """ - If max_position_embeddings <= 0, there's no position embeddings - If type_vocab_size <= 0, there's no token type embeddings - """ - super().__init__() - self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) - self.max_position_embeddings = config.max_position_embeddings if config.rotary_emb_fraction <= 0 else 0 - self.type_vocab_size = config.type_vocab_size - if self.max_position_embeddings > 0 and config.rotary_emb_fraction <= 0: - self.position_embeddings = nn.Embedding( - config.max_position_embeddings, - config.hidden_size, - ) - if self.type_vocab_size > 0: - self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) - - def forward(self, input_ids=None, position_ids=None, token_type_ids=None, inputs_embeds=None): - """ - input_ids: (batch, seqlen) - position_ids: (batch, seqlen) - token_type_ids: (batch, seqlen) - """ - if inputs_embeds is None: - embeddings = self.word_embeddings(input_ids) - else: - embeddings = inputs_embeds - batch_size, seqlen, _ = embeddings.shape - - if self.type_vocab_size > 0: - if token_type_ids is None: - token_type_ids = torch.zeros(seqlen, dtype=torch.long, device=embeddings.device) - token_type_embeddings = self.token_type_embeddings(token_type_ids) - embeddings = embeddings + token_type_embeddings - - if self.max_position_embeddings > 0: - if position_ids is None: - position_ids = torch.arange(seqlen, dtype=torch.long, device=embeddings.device) - position_embeddings = self.position_embeddings(position_ids) - embeddings = embeddings + position_embeddings - return embeddings - - -class NomicBertMLP(nn.Module): - def __init__( - self, - in_features, - hidden_features=None, - out_features=None, - activation=F.gelu, - bias1=True, - bias2=True, - return_residual=False, - fused_bias_fc=False, - ): - super().__init__() - out_features = out_features if out_features is not None else in_features - hidden_features = hidden_features if hidden_features is not None else in_features * 4 - self.return_residual = return_residual - self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1) - approximate = "tanh" if activation in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"] else "none" - self.activation = nn.GELU(approximate=approximate) if activation == "gelu" else activation - self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2) - - def forward(self, x): - y = self.fc1(x) - y = self.activation(y) - y = self.fc2(y) - return y if not self.return_residual else (y, x) - - -class NomciBertGatedMLP(nn.Module): - def __init__( - self, - in_features, - hidden_features=None, - out_features=None, - activation=F.sigmoid, - bias1=True, - bias2=True, - multiple_of=256, - return_residual=False, - fused_bias_fc=True, - device=None, - dtype=None, - norm_layer=False, - ): - super().__init__() - out_features = out_features if out_features is not None else in_features - hidden_features = hidden_features if hidden_features is not None else int(8 * in_features / 3) - hidden_features = int((hidden_features + multiple_of - 1) // multiple_of * multiple_of) - self.return_residual = return_residual - - self.fc11 = nn.Linear(in_features, hidden_features, bias=bias1) - self.fc12 = nn.Linear(in_features, hidden_features, bias=bias1) - self.activation = activation - self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2) - self.norm = nn.LayerNorm(hidden_features) if norm_layer else nn.Identity() - - def forward(self, x): - y = self.fc11(x) - gate = self.fc12(x) - if self.activation == F.sigmoid: # Special case for GLU - y = F.glu(torch.cat([y, gate], dim=-1), dim=-1) - else: - y = y * self.activation(gate) - - # eva uses layer norm after the activation - y = self.norm(y) - - y = self.fc2(y) - return y if not self.return_residual else (y, x) - - -def rotate_half(x, interleaved=False): - if not interleaved: - x1, x2 = x.chunk(2, dim=-1) - return torch.cat((-x2, x1), dim=-1) - else: - x1, x2 = x[..., ::2], x[..., 1::2] - return rearrange(torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2) - - -def apply_rotary_emb(x, cos, sin, offset=0, interleaved=False): - """ - x: (batch_size, seqlen, nheads, headdim) - cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2) - """ - ro_dim = cos.shape[-1] * 2 - assert ro_dim <= x.shape[-1] - cos, sin = ( - cos[offset : offset + x.shape[1]], - sin[offset : offset + x.shape[1]], - ) - cos = repeat(cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)") - sin = repeat(sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)") - return torch.cat( - [x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:]], - dim=-1, - ) - - -class NomicBertRotaryEmbedding(nn.Module): - def __init__( - self, - dim: int, - base=10000.0, - interleaved=False, - scale_base=None, - pos_idx_in_fp32=True, - device=None, - ): - """ - interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead - of 1st half and 2nd half (GPT-NeoX style). - pos_idx_in_fp32: if True, the position indices [0.0, ..., seqlen - 1] are in fp32, - otherwise they might be in lower precision. - This option was added because previously (before 2023-07-02), when we construct - the position indices, we use the dtype of self.inv_freq. In most cases this would - be fp32, but if the model is trained in pure bf16 (not mixed precision), then - self.inv_freq would be bf16, and the position indices are also in bf16. - Because of the limited precision of bf16 (e.g. 1995.0 is rounded to 2000.0), the - embeddings for some positions will coincide. - To maintain compatibility with models previously trained in pure bf16, - we add this option. - """ - super().__init__() - self.dim = dim - self.base = float(base) - self.pos_idx_in_fp32 = pos_idx_in_fp32 - # Generate and save the inverse frequency buffer (non trainable) - inv_freq = self._compute_inv_freq(device) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self.interleaved = interleaved - self.scale_base = scale_base - scale = ( - (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim) - if scale_base is not None - else None - ) - self.register_buffer("scale", scale, persistent=False) - - self._seq_len_cached = 0 - self._cos_cached = None - self._sin_cached = None - self._cos_k_cached = None - self._sin_k_cached = None - - def _compute_inv_freq(self, device=None): - return 1.0 / (self.base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim)) - - def _update_cos_sin_cache(self, seqlen, device=None, dtype=None): - # Reset the tables if the sequence length has changed, - # if we're on a new device (possibly due to tracing for instance), - # or if we're switching from inference mode to training - if ( - seqlen > self._seq_len_cached - or self._cos_cached is None - or self._cos_cached.device != device - or self._cos_cached.dtype != dtype - or (self.training and self._cos_cached.is_inference()) - ): - self._seq_len_cached = seqlen - # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16 - # And the output of arange can be quite large, so bf16 would lose a lot of precision. - # However, for compatibility reason, we add an option to use the dtype of self.inv_freq. - if self.pos_idx_in_fp32: - t = torch.arange(seqlen, device=device, dtype=torch.float32) - # We want fp32 here as well since inv_freq will be multiplied with t, and the output - # will be large. Having it in bf16 will lose a lot of precision and cause the - # cos & sin output to change significantly. - # We want to recompute self.inv_freq if it was not loaded in fp32 - if self.inv_freq.dtype != torch.float32: - inv_freq = self._compute_inv_freq(device=device) - else: - inv_freq = self.inv_freq - else: - t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype) - inv_freq = self.inv_freq - # Don't do einsum, it converts fp32 to fp16 under AMP - # freqs = torch.einsum("i,j->ij", t, self.inv_freq) - freqs = torch.outer(t, inv_freq) - self._cos_cached = torch.cos(freqs).to(dtype) - self._sin_cached = torch.sin(freqs).to(dtype) - - def forward( - self, - qkv: torch.Tensor, - kv: Optional[torch.Tensor] = None, - seqlen_offset: Union[int, torch.Tensor] = 0, - max_seqlen: Optional[int] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - qkv: (batch, seqlen, 3, nheads, headdim) if kv is none, - else it's just q of shape (batch, seqlen, nheads, headdim) - kv: (batch, seqlen, 2, nheads, headdim) - seqlen_offset: (batch_size,) or int. Each sequence in x is shifted by this amount. - Most commonly used in inference when we have KV cache. - If it's a tensor of shape (batch_size,), then to update the cos / sin cache, one - should pass in max_seqlen, which will update the cos / sin cache up to that length. - Apply rotary embedding *inplace* to qkv and / or kv. - """ - seqlen = qkv.shape[1] - if seqlen > self._seq_len_cached: - self._update_cos_sin_cache(seqlen, device=qkv.device, dtype=qkv.dtype) - elif max_seqlen is not None: - self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype) - elif isinstance(seqlen_offset, int): - self._update_cos_sin_cache(seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype) - - q_rot = apply_rotary_emb(qkv[:, :, 0], self._cos_cached, self._sin_cached, seqlen_offset, self.interleaved) - k_rot = apply_rotary_emb(qkv[:, :, 1], self._cos_cached, self._sin_cached, seqlen_offset, self.interleaved) - return torch.stack((q_rot, k_rot, qkv[:, :, 2]), dim=2) - - -class NomicBertDynamicNTKRotaryEmbedding(NomicBertRotaryEmbedding): - def __init__(self, rotary_scaling_factor, max_position_embeddings, **kwargs): - super().__init__(**kwargs) - self.rotary_scaling_factor = rotary_scaling_factor - self.max_position_embeddings = max_position_embeddings - - def _compute_inv_freq(self, base=None, device=None): - if base is None: - base = self.base - return 1.0 / (base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim)) - - def _update_cos_sin_cache(self, seqlen, device=None, dtype=None): - # Reset the tables if the sequence length has changed, - # if we're on a new device (possibly due to tracing for instance), - # or if we're switching from inference mode to training - if seqlen > self.max_position_embeddings: - base = self.base * ( - (self.rotary_scaling_factor * seqlen / self.max_position_embeddings) - (self.rotary_scaling_factor - 1) - ) ** (self.dim / (self.dim - 2)) - inv_freq = self._compute_inv_freq(base=base, device=device) - self.register_buffer("inv_freq", inv_freq, persistent=False) - - if ( - seqlen > self._seq_len_cached - or self._cos_cached is None - or self._cos_cached.device != device - or self._cos_cached.dtype != dtype - or (self.training and self._cos_cached.is_inference()) - ): - self._seq_len_cached = seqlen - # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16 - # And the output of arange can be quite large, so bf16 would lose a lot of precision. - # However, for compatibility reason, we add an option to use the dtype of self.inv_freq. - if self.pos_idx_in_fp32: - t = torch.arange(seqlen, device=device, dtype=torch.float32) - # We want fp32 here as well since inv_freq will be multiplied with t, and the output - # will be large. Having it in bf16 will lose a lot of precision and cause the - # cos & sin output to change significantly. - # We want to recompute self.inv_freq if it was not loaded in fp32 - if self.inv_freq.dtype != torch.float32: - if seqlen > self.max_position_embeddings: - base = self.base * ( - (self.scaling_factor * seqlen / self.max_position_embeddings) - (self.scaling_factor - 1) - ) ** (self.dim / (self.dim - 2)) - else: - base = self.base - inv_freq = self._compute_inv_freq(device=device, base=base) - else: - inv_freq = self.inv_freq - else: - t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype) - inv_freq = self.inv_freq - # Don't do einsum, it converts fp32 to fp16 under AMP - # freqs = torch.einsum("i,j->ij", t, self.inv_freq) - freqs = torch.outer(t, inv_freq) - if self.scale is None: - self._cos_cached = torch.cos(freqs).to(dtype) - self._sin_cached = torch.sin(freqs).to(dtype) - else: - power = ( - torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) - seqlen // 2 - ) / self.scale_base - scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1") - # We want the multiplication by scale to happen in fp32 - self._cos_cached = (torch.cos(freqs) * scale).to(dtype) - self._sin_cached = (torch.sin(freqs) * scale).to(dtype) - self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype) - self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype) - - -class NomicBertAttention(nn.Module): - """Multi-head self-attention and cross-attention""" - - def __init__( - self, - config, - ) -> None: - """ - num_heads_kv: can be used to toggle MQA / GQA. If None, use num_heads. - return_residual: whether to return the input x along with the output. This is for - performance reason: for post-norm architecture, returning the input allows us - to fuse the backward of nn.Linear with the residual connection. - """ - super().__init__() - self.embed_dim = config.n_embd - self.use_flash_attn = config.use_flash_attn - self.fused_bias_fc = config.fused_bias_fc - - self.num_heads = config.n_head - self.num_heads_kv = config.num_heads_kv if getattr(config, "num_heads_kv", None) is not None else self.num_heads - assert self.embed_dim % self.num_heads == 0, "embed_dim must be divisible by num_heads" - self.head_dim = self.embed_dim // self.num_heads - # we don't really support mqa / gqa for now - qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv) - - self.register_buffer( - "norm_factor", - torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32)).to(torch.get_default_dtype()), - persistent=False, - ) - - self.rotary_emb_dim = self.head_dim * config.rotary_emb_fraction - if self.rotary_emb_dim > 0: - if getattr(config, "rotary_scaling_factor", None): - self.rotary_emb = NomicBertDynamicNTKRotaryEmbedding( - dim=self.rotary_emb_dim, - base=config.rotary_emb_base, - scale_base=config.rotary_emb_scale_base, - interleaved=config.rotary_emb_interleaved, - rotary_scaling_factor=config.rotary_scaling_factor, - max_position_embeddings=config.max_trained_positions, - ) - else: - self.rotary_emb = NomicBertRotaryEmbedding( - dim=self.rotary_emb_dim, - base=config.rotary_emb_base, - scale_base=config.rotary_emb_scale_base, - interleaved=config.rotary_emb_interleaved, - ) - # bug in xformers: https://github.com/facebookresearch/xformers/issues/841 - # uses the head dimension instead of the sequence dimension - self.rotary_head_dim = getattr(config, "rotary_head_dim", False) - - self.Wqkv = nn.Linear(self.embed_dim, qkv_dim, bias=config.qkv_proj_bias) - - self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_proj_bias) - self.causal = config.causal - self.drop = nn.Dropout(config.attn_pdrop) - self.num_prefix_tokens = max(getattr(config, "register_tokens", 1), 1) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: bool = False, - use_cache: bool = False, - is_padded_inputs: Optional[bool] = True, - cu_seqlens: Optional[torch.Tensor] = None, - max_seq_len: Optional[int] = None, - rope: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - - has_layer_past = past_key_value is not None - - if has_layer_past: - past_key_value = past_key_value[0] - past_len = past_key_value[1] - else: - past_len = 0 - - qkv = self.Wqkv(hidden_states) - qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim) - - past_key_value = (past_key_value, past_len + qkv.size(1)) if use_cache else None - - if self.rotary_emb_dim > 0: - if self.rotary_head_dim: - qkv = rearrange(qkv, "b s three h d -> b h three s d") - qkv = self.rotary_emb(qkv, seqlen_offset=past_len) - - if self.rotary_head_dim: - qkv = rearrange(qkv, "b h three s d -> b s three h d") - elif rope is not None: - q, k, v = qkv.permute(0, 3, 1, 2, 4).unbind(dim=-2) - q = torch.cat( - [q[:, :, : self.num_prefix_tokens], apply_rot_embed_cat(q[:, :, self.num_prefix_tokens :], rope)], dim=2 - ).type_as(q) - k = torch.cat( - [k[:, :, : self.num_prefix_tokens], apply_rot_embed_cat(k[:, :, self.num_prefix_tokens :], rope)], dim=2 - ).type_as(q) - - qkv = torch.stack([q, k, v], dim=-2) - qkv = rearrange(qkv, "b h s three d -> b s three h d") - - query, key, value = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2] - - query = query.permute(0, 2, 1, 3) - key = key.permute(0, 2, 1, 3) - value = value.permute(0, 2, 1, 3) - if scaled_dot_product_attention is not None: - attn_output = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=self.drop.p, is_causal=False - ) - else: - attention_scores = torch.matmul(query, key.transpose(-1, -2)) / self.norm_factor - if attention_mask is not None: - attention_scores = attention_scores + attention_mask - - attentions_probs = F.softmax(attention_scores, dim=-1) - attentions_probs = self.drop(attentions_probs) - - attn_output = torch.matmul(attentions_probs, value) - - attn_output = rearrange(attn_output.permute(0, 2, 1, 3), "... h d -> ... (h d)") - - attn_output = self.out_proj(attn_output) - - return attn_output - - -class NomicBertBlock(NomicBertPreTrainedModel): - def __init__( - self, - config, - ): - super().__init__(config=config) - self.prenorm = config.prenorm - self.fused_dropout_add_ln = config.fused_dropout_add_ln - - self.attn = NomicBertAttention(config) - activation = ( - F.sigmoid - if config.activation_function == "glu" - else (F.silu if config.activation_function == "swiglu" else F.gelu) - ) - if config.activation_function in ["glu", "swiglu", "geglu"]: - self.mlp = NomciBertGatedMLP( - config.n_embd, - hidden_features=config.n_inner, - bias1=config.mlp_fc1_bias, - bias2=config.mlp_fc2_bias, - activation=activation, - fused_bias_fc=config.fused_bias_fc, - norm_layer=getattr(config, "norm_mlp", False), - ) - else: - self.mlp = NomicBertMLP( - config.n_embd, - hidden_features=config.n_inner, - bias1=config.mlp_fc1_bias, - bias2=config.mlp_fc2_bias, - activation=activation, - fused_bias_fc=config.fused_bias_fc, - ) - - self.dropout1 = nn.Dropout(config.resid_pdrop) - self.norm1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) - self.norm2 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) - self.dropout2 = nn.Dropout(config.resid_pdrop) - - def forward( - self, - hidden_states: torch.Tensor, - hidden_states2: torch.Tensor, - residual: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - is_padded_inputs: Optional[bool] = True, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - cu_seqlens: Optional[torch.Tensor] = None, - max_seq_len: Optional[int] = None, - rope: Optional[torch.Tensor] = None, - ): - r"""Pass the input through the encoder layer. - - Args: - hidden_states: the sequence to the encoder layer (required). - residual: if postnorm, residual=None, If prenorm, hidden_states = Attn/MLP(LN(residual)) - mixer_subset: for cross-attention only. If not None, will take a subset of x - before applying the query projection. Useful for e.g., ViT where we only care - about the CLS token in the last layer. - """ - if self.prenorm: - dropped = self.dropout1(hidden_states) - residual = (dropped + residual) if residual is not None else dropped - hidden_states = self.norm1(residual.to(dtype=self.norm1.weight.dtype)) - hidden_states = self.attn( - hidden_states, - attention_mask=attention_mask, - is_padded_inputs=is_padded_inputs, - cu_seqlens=cu_seqlens, - max_seq_len=max_seq_len, - rope=rope, - ) - - dropped = self.dropout2(hidden_states) - residual = (dropped + residual) if residual is not None else dropped - hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype)) - hidden_states = self.mlp(hidden_states) - - return hidden_states, None, residual - else: - assert residual is None - attn_outputs = self.attn( - hidden_states, - attention_mask=attention_mask, - is_padded_inputs=is_padded_inputs, - cu_seqlens=cu_seqlens, - max_seq_len=max_seq_len, - rope=rope, - ) - hidden_states = self.norm1((self.dropout1(attn_outputs) + hidden_states).to(dtype=self.norm1.weight.dtype)) - mlp_out = self.mlp(hidden_states) - - hidden_states = self.norm2((self.dropout2(mlp_out) + hidden_states).to(dtype=self.norm2.weight.dtype)) - return hidden_states, None, None - - -class NomicBertEncoder(nn.Module): - def __init__(self, config: GPT2Config): - super().__init__() - self.layers = nn.ModuleList([NomicBertBlock(config) for _ in range(config.n_layer)]) - self.gradient_checkpointing = False - self.config = config - - def forward( - self, - hidden_states: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - is_padded_inputs: Optional[bool] = True, - rope: Optional[torch.Tensor] = None, - ): - """If subset_mask is not None, we only want output for the subset of the sequence. - This means that we only compute the last layer output for these tokens. - subset_mask: (batch, seqlen), dtype=torch.bool - """ - hidden_states2 = None - residual = None - - for _, layer in enumerate(self.layers): - if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs) - - return custom_forward - - hidden_states, hidden_states2, residual = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer), - hidden_states, - hidden_states2, - residual, - attention_mask, - position_ids, - past_key_values, - is_padded_inputs, - output_attentions, - use_cache, - None, - None, - rope, - # if you freeze ANY layers, you need `use_reentrant=False` - # https://github.com/huggingface/transformers/issues/21381 - # https://discuss.pytorch.org/t/checkpoint-with-no-grad-requiring-inputs-problem/19117/7 - use_reentrant=False, - ) - - else: - hidden_states, hidden_states2, residual = layer( - hidden_states, - hidden_states2, - residual, - attention_mask, - position_ids, - None, - is_padded_inputs, - output_attentions, - use_cache, - rope=rope, - ) - return hidden_states - - -class NomicBertPooler(nn.Module): - def __init__(self, config): - super().__init__() - self.dense = nn.Linear(config.n_embd, config.n_embd) - self.activation = nn.Tanh() - - def forward(self, hidden_states, pool=True): - # We "pool" the model by simply taking the hidden state corresponding - # to the first token. - first_token_tensor = hidden_states[:, 0] if pool else hidden_states - pooled_output = self.dense(first_token_tensor) - pooled_output = self.activation(pooled_output) - return pooled_output - - -class NomicBertPredictionHeadTransform(nn.Module): - def __init__(self, config): - super().__init__() - self.dense = nn.Linear(config.n_embd, config.n_embd, bias=config.mlp_fc1_bias) - approximate = "tanh" if config.activation_function in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"] else "none" - if config.activation_function == "swiglu": - self.transform_act_fn = F.silu - else: - self.transform_act_fn = nn.GELU(approximate=approximate) - - self.layer_norm = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - hidden_states = self.dense(hidden_states) - hidden_states = self.transform_act_fn(hidden_states) - hidden_states = self.layer_norm(hidden_states) - - return hidden_states - - -class NomicBertLMPredictionHead(nn.Module): - def __init__(self, config): - super().__init__() - - self.transform = NomicBertPredictionHeadTransform(config) - - self.decoder = nn.Linear(config.n_embd, config.vocab_size, bias=config.mlp_fc1_bias) - - def forward(self, hidden_states): - hidden_states = self.transform(hidden_states) - hidden_states = self.decoder(hidden_states) - return hidden_states - - -class NomicBertPreTrainingHeads(nn.Module): - def __init__(self, config): - super().__init__() - self.predictions = NomicBertLMPredictionHead(config) - - def forward(self, sequence_output): - prediction_scores = self.predictions(sequence_output) - return prediction_scores - - -class NomicBertModel(NomicBertPreTrainedModel): - def __init__(self, config: GPT2Config, add_pooling_layer=True): - super().__init__(config) - self.pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1) - if config.vocab_size % self.pad_vocab_size_multiple != 0: - config.vocab_size += self.pad_vocab_size_multiple - (config.vocab_size % self.pad_vocab_size_multiple) - - assert config.activation_function in [ - "gelu", - "gelu_new", - "gelu_fast", - "gelu_pytorch_tanh", - "swiglu", - "geglu", - "glu", - ] - - self.embeddings = NomicBertEmbeddings(config) - self.emb_drop = nn.Dropout(config.resid_pdrop) - self.emb_ln = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) - self.encoder = NomicBertEncoder(config) - self.pooler = NomicBertPooler(config) if add_pooling_layer else None - - self.apply(partial(_init_weights, initializer_range=config.initializer_range)) - - def forward( - self, - input_ids=None, - attention_mask=None, - position_ids=None, - token_type_ids=None, - return_dict=None, - matryoshka_dim=None, - inputs_embeds=None, - ): - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - hidden_states = self.embeddings( - input_ids=input_ids, - position_ids=position_ids, - token_type_ids=token_type_ids, - inputs_embeds=inputs_embeds, - ) - hidden_states = self.emb_ln(hidden_states) - hidden_states = self.emb_drop(hidden_states) - - attention_mask = self.get_extended_attention_mask(attention_mask, hidden_states.shape[:-1]) - sequence_output = self.encoder(hidden_states, attention_mask=attention_mask, return_dict=return_dict) - - pooled_output = self.pooler(sequence_output) if self.pooler is not None else None - - if matryoshka_dim: - sequence_output = sequence_output[:, :matryoshka_dim] - - return BaseModelOutputWithPoolingAndCrossAttentions( - last_hidden_state=sequence_output, - pooler_output=pooled_output, - ) - - -class NomicBertForPreTraining(NomicBertPreTrainedModel): - _tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"] - - def __init__(self, config: GPT2Config): - super().__init__(config) - - self.bert = NomicBertModel(config, add_pooling_layer=getattr(config, "add_pooling_layer", False)) - self.cls = NomicBertPreTrainingHeads(config) - self.mlm_loss = nn.CrossEntropyLoss() - - # Initialize weights and apply final processing - self.apply(partial(_init_weights, initializer_range=config.initializer_range)) - self.tie_weights() - - def tie_weights(self): - self.cls.predictions.decoder.weight = self.bert.embeddings.word_embeddings.weight - - def forward( - self, - input_ids, - position_ids=None, - token_type_ids=None, - attention_mask=None, - labels=None, - ): - """ - If labels are provided, they must be -100 for masked out tokens (as specified in the attention - mask). - Outputs: - if `labels` and `next_sentence_label` are not `None`: - Outputs the total_loss which is the sum of the masked language modeling loss and the next - sentence classification loss. - if `labels` or `next_sentence_label` is `None`: - Outputs a tuple comprising - - the masked language modeling logits of shape [batch_size, sequence_length, vocab_size], and - - the next sentence classification logits of shape [batch_size, 2]. - - """ - outputs = self.bert( - input_ids, - position_ids=position_ids, - token_type_ids=token_type_ids, - attention_mask=attention_mask.bool() if attention_mask is not None else None, - ) - sequence_output, _ = outputs.last_hidden_state, outputs.pooler_output - - prediction_scores = self.cls(sequence_output) - - total_loss = None - if labels is not None: - masked_lm_loss = self.mlm_loss( - rearrange(prediction_scores, "... v -> (...) v"), - rearrange(labels, "... -> (...)"), - ) - total_loss = masked_lm_loss.float() - - return MaskedLMOutput( - loss=total_loss, - logits=prediction_scores, - hidden_states=outputs.hidden_states, - attentions=None, - ) - - -class NomicBertForSequenceClassification(NomicBertPreTrainedModel): - def __init__(self, config): - super().__init__(config) - self.num_labels = config.num_labels - self.config = config - - self.bert = NomicBertModel(config) - classifier_dropout = getattr(config, "classifier_dropout", config.embd_pdrop) - self.dropout = nn.Dropout(classifier_dropout) - self.classifier = nn.Linear(config.n_embd, config.num_labels) - - # Initialize weights and apply final processing - self.post_init() - - def forward( - self, - input_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - token_type_ids: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - head_mask: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ): - r""" - labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). - """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = self.bert( - input_ids, - position_ids=position_ids, - token_type_ids=token_type_ids, - attention_mask=attention_mask.bool() if attention_mask is not None else None, - ) - - pooled_output = outputs[1] - - pooled_output = self.dropout(pooled_output) - logits = self.classifier(pooled_output) - - loss = None - if labels is not None: - if self.config.problem_type is None: - if self.num_labels == 1: - self.config.problem_type = "regression" - elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): - self.config.problem_type = "single_label_classification" - else: - self.config.problem_type = "multi_label_classification" - - if self.config.problem_type == "regression": - loss_fct = nn.MSELoss() - if self.num_labels == 1: - loss = loss_fct(logits.squeeze(), labels.squeeze()) - else: - loss = loss_fct(logits, labels) - elif self.config.problem_type == "single_label_classification": - loss_fct = nn.CrossEntropyLoss() - loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) - elif self.config.problem_type == "multi_label_classification": - loss_fct = nn.BCEWithLogitsLoss() - loss = loss_fct(logits, labels) - if not return_dict: - output = (logits,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return SequenceClassifierOutput( - loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - -class NomicBertForMultipleChoice(NomicBertPreTrainedModel): - def __init__(self, config): - super().__init__(config) - - self.bert = NomicBertModel(config, add_pooling_layer=True) - classifier_dropout = ( - getattr(config, "classifier_dropout", config.resid_pdrop) - ) - self.dropout = nn.Dropout(classifier_dropout) - self.classifier = nn.Linear(config.hidden_size, 1) - - # Initialize weights and apply final processing - self.post_init() - - def forward( - self, - input_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - token_type_ids: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - head_mask: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - unpad_inputs: Optional[bool] = None, - ) -> Union[Tuple[torch.Tensor], MultipleChoiceModelOutput]: - r""" - labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., - num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See - `input_ids` above) - """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1] - - input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None - attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None - token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None - position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None - inputs_embeds = ( - inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) - if inputs_embeds is not None - else None - ) - - outputs = self.bert( - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - inputs_embeds=inputs_embeds, - ) - - pooled_output = outputs[1] - - pooled_output = self.dropout(pooled_output) - logits = self.classifier(pooled_output) - reshaped_logits = logits.view(-1, num_choices) - - loss = None - if labels is not None: - loss_fct = nn.CrossEntropyLoss() - loss = loss_fct(reshaped_logits, labels) - - if not return_dict: - output = (reshaped_logits,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return MultipleChoiceModelOutput( - loss=loss, - logits=reshaped_logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - -class NomicBertForTokenClassification(NomicBertPreTrainedModel): - def __init__(self, config): - super().__init__(config) - self.num_labels = config.num_labels - - self.bert = NomicBertModel(config, add_pooling_layer=False) - classifier_dropout = ( - getattr(config, "classifier_dropout", config.resid_pdrop) - ) - self.dropout = nn.Dropout(classifier_dropout) - self.classifier = nn.Linear(config.hidden_size, config.num_labels) - - # Initialize weights and apply final processing - self.post_init() - - def forward( - self, - input_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - token_type_ids: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - head_mask: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]: - r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. - """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - outputs = self.bert( - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - inputs_embeds=inputs_embeds, - ) - - sequence_output = outputs[0] - - sequence_output = self.dropout(sequence_output) - logits = self.classifier(sequence_output) - - loss = None - if labels is not None: - loss_fct = nn.CrossEntropyLoss() - loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) - - if not return_dict: - output = (logits,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return TokenClassifierOutput( - loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - -class NomicBertForQuestionAnswering(NomicBertPreTrainedModel): - def __init__(self, config): - super().__init__(config) - self.num_labels = config.num_labels - - self.bert = NomicBertModel(config, add_pooling_layer=False) - self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) - - # Initialize weights and apply final processing - self.post_init() - - def forward( - self, - input_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - token_type_ids: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - head_mask: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, - start_positions: Optional[torch.Tensor] = None, - end_positions: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]: - r""" - start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for position (index) of the start of the labelled span for computing the token classification loss. - Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence - are not taken into account for computing the loss. - end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for position (index) of the end of the labelled span for computing the token classification loss. - Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence - are not taken into account for computing the loss. - """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - outputs = self.bert( - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - inputs_embeds=inputs_embeds, - ) - - sequence_output = outputs[0] - - logits = self.qa_outputs(sequence_output) - start_logits, end_logits = logits.split(1, dim=-1) - start_logits = start_logits.squeeze(-1).contiguous() - end_logits = end_logits.squeeze(-1).contiguous() - - total_loss = None - if start_positions is not None and end_positions is not None: - # If we are on multi-GPU, split add a dimension - if len(start_positions.size()) > 1: - start_positions = start_positions.squeeze(-1) - if len(end_positions.size()) > 1: - end_positions = end_positions.squeeze(-1) - # sometimes the start/end positions are outside our model inputs, we ignore these terms - ignored_index = start_logits.size(1) - start_positions = start_positions.clamp(0, ignored_index) - end_positions = end_positions.clamp(0, ignored_index) - - loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index) - start_loss = loss_fct(start_logits, start_positions) - end_loss = loss_fct(end_logits, end_positions) - total_loss = (start_loss + end_loss) / 2 - - if not return_dict: - output = (start_logits, end_logits) + outputs[2:] - return ((total_loss,) + output) if total_loss is not None else output - - return QuestionAnsweringModelOutput( - loss=total_loss, - start_logits=start_logits, - end_logits=end_logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - -def hf_vit_config_to_vit_config(vit_config: ViTConfig) -> GPT2Config: - return GPT2Config( - n_embd=vit_config.hidden_size, - n_layer=vit_config.num_hidden_layers, - n_head=vit_config.num_attention_heads, - n_inner=vit_config.intermediate_size, - activation_function=vit_config.hidden_act, - vocab_size=0, # no vocab since using patches - n_positions=0, # No absolute position embedding - resid_pdrop=0.0, # No dropout - embd_pdrop=getattr(vit_config, "dropout", 0.0), - attn_pdrop=vit_config.attention_probs_dropout_prob, - layer_norm_epsilon=vit_config.layer_norm_eps, - initializer_range=vit_config.initializer_range, - bos_token_id=None, - eos_token_id=None, - # These are new arguments not in the original GPT2Config - drop_path_rate=0.0, - # Why is there double layer norm?? - prepre_layernom=False, - layer_scale=False, - layer_scale_init=None, - img_size=vit_config.image_size, - patch_size=vit_config.patch_size, - num_channels=vit_config.num_channels, - prenorm=True, - parallel_block=False, - parallel_block_tied_norm=False, - rotary_emb_fraction=0, - tie_word_embeddings=False, - fused_dropout_add_ln=True, - fused_bias_fc=True, - patch_embed_bias=True, - use_flash_attn=True, - qkv_proj_bias=True, - mlp_fc1_bias=getattr(vit_config, "mlp_fc1_bias", True), - mlp_fc2_bias=getattr(vit_config, "mlp_fc2_bias", True), - use_rms_norm=False, - causal=False, - hidden_features_scaling_factor=1.0, - mask_token=False, - learned_pos_embedding=False, - patch_dropout=0, - sinusoidal_pos_embedding=vit_config.model_type == "vit_mae", - ) - - -class NomicAttentionPooling(nn.Module): - def __init__(self, config): - super().__init__() - self.embed_dim = config.n_embd - self.use_flash_attn = config.use_flash_attn - self.fused_bias_fc = config.fused_bias_fc - - self.num_heads = config.n_head - self.num_heads_kv = config.num_heads_kv if getattr(config, "num_heads_kv", None) is not None else self.num_heads - assert self.embed_dim % self.num_heads == 0, "embed_dim must be divisible by num_heads" - self.head_dim = self.embed_dim // self.num_heads - # we don't really support mqa / gqa for now - kv_dim = 2 * self.head_dim * self.num_heads_kv - - self.register_buffer( - "norm_factor", - torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32)).to(torch.get_default_dtype()), - persistent=False, - ) - - self.Wq = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_proj_bias) - self.Wkv = nn.Linear(self.embed_dim, kv_dim, bias=config.qkv_proj_bias) - - self.latent = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) - - self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_proj_bias) - self.causal = config.causal - self.drop = nn.Dropout(config.attn_pdrop) - - def init_weights(self): - trunc_normal_tf_(self.latent, std=self.embed_dim**-0.5) - - def forward( - self, - kv, - attention_mask=None, - cu_seqlens_k=None, - max_seqlen_k=None, - is_padded_inputs: Optional[bool] = True, - output_attentions: bool = False, - ): - """Implements the multihead softmax attention. - Arguments - --------- - q: The tensor containing the query. (B, Sq, H, D) - kv: The tensor containing the key and value. (B, Sk, 2, H_k, D) - causal: if passed, will override self.causal - cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths - of the sequences in the batch, used to index into q. - max_seqlen: int. Maximum sequence length in the batch of q. - cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths - of the sequences in the batch, used to index into kv. - max_seqlen_k: int. Maximum sequence length in the batch of k and v. - """ - q_latent = self.latent.expand(kv.size(0), -1, -1) - q = self.Wq(q_latent) - bsz, q_len, h_size = q.shape - kv = self.Wkv(kv) - query = rearrange(q, "... (h d) -> ... h d", d=self.head_dim) - kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim) - - key, value = kv[:, :, 0], kv[:, :, 1] - - query = query.permute(0, 2, 1, 3) - key = key.permute(0, 2, 1, 3) - value = value.permute(0, 2, 1, 3) - - attention_scores = torch.matmul(query, key.transpose(-1, -2)) / self.norm_factor - if attention_mask is not None: - attention_scores = attention_scores + attention_mask - - attentions_probs = F.softmax(attention_scores, dim=-1) - attentions_probs = self.drop(attentions_probs) - - attn_output = torch.matmul(attentions_probs, value) - attn_output = rearrange(attn_output.permute(0, 2, 1, 3), "... h d -> ... (h d)") - - attn_output = self.out_proj(attn_output) - - return attn_output - - -class NomicMultiHeadAttentionPooling(nn.Module): - def __init__( - self, - config, - ): - super().__init__() - self.prenorm = config.prenorm - self.fused_dropout_add_ln = config.fused_dropout_add_ln - - self.attn = NomicAttentionPooling(config) - activation = ( - F.sigmoid - if config.activation_function == "glu" - else (F.silu if config.activation_function == "swiglu" else F.gelu) - ) - if config.activation_function in ["glu", "swiglu", "geglu"]: - self.mlp = NomciBertGatedMLP( - config.n_embd, - hidden_features=config.n_inner, - bias1=config.mlp_fc1_bias, - bias2=config.mlp_fc2_bias, - activation=activation, - fused_bias_fc=config.fused_bias_fc, - ) - else: - self.mlp = NomicBertMLP( - config.n_embd, - hidden_features=config.n_inner, - bias1=config.mlp_fc1_bias, - bias2=config.mlp_fc2_bias, - activation=activation, - fused_bias_fc=config.fused_bias_fc, - ) - - self.dropout1 = nn.Dropout(config.resid_pdrop) - self.norm1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) - self.dropout2 = nn.Dropout(config.resid_pdrop) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - ): - r"""Pass the input through the encoder layer. - - Args: - hidden_states: the sequence to the encoder layer (required). - residual: if postnorm, residual=None, If prenorm, hidden_states = Attn/MLP(LN(residual)) - mixer_subset: for cross-attention only. If not None, will take a subset of x - before applying the query projection. Useful for e.g., ViT where we only care - about the CLS token in the last layer. - """ - - attn_outputs = self.attn( - hidden_states, - attention_mask=attention_mask, - ) - - normed = self.norm1(attn_outputs) - hidden_states = hidden_states + self.mlp(normed) - - return hidden_states - - -class NomicVisionPreTrainedModel(PreTrainedModel): - """An abstract class to handle weights initialization and - a simple interface for dowloading and loading pretrained models. - """ - - config_class = NomicBertConfig - base_model_prefix = "model" - supports_gradient_checkpointing = True - _no_split_modules = ["Block"] - _skip_keys_device_placement = "past_key_values" - - def __init__(self, config, *inputs, **kwargs): - super().__init__(config) - if not isinstance(config, GPT2Config): - raise ValueError( - "Parameter config in `{}(config)` should be an instance of class `GPT2Config`. " - "To create a model from a Google pretrained model use " - "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format( - self.__class__.__name__, self.__class__.__name__ - ) - ) - self.config = config - - -class NomicVisionModel(NomicVisionPreTrainedModel): - def __init__(self, config): - super().__init__(config) - - self.embeddings = NomicVisionPatchEmbeddings(config) - self.layers = nn.ModuleList([NomicBertBlock(config) for _ in range(config.n_layer)]) - - self.selector = NomicMultiHeadAttentionPooling(config) - - self.global_pool = getattr(config, "global_pool", None) - self.num_prefix_tokens = (1 if not getattr(config, "no_cls_token", False) else 0) + getattr( - config, "register_tokens", 0 - ) - - self.apply(partial(_init_weights, initializer_range=config.initializer_range)) - - def forward( - self, - pixel_values, - attention_mask=None, - position_ids=None, - token_type_ids=None, - return_dict=None, - matryoshka_dim=None, - ): - embeddings, rope = self.embeddings(pixel_values) - - original_dtype = embeddings.dtype - - hidden_states = embeddings - # unused but easier to pass to gradient checkpointing as words - residual = None - for layer in self.layers: - # need to pass none for backwards compatability - hidden_states, _, residual = layer( - hidden_states, None, residual=residual, is_padded_inputs=False, rope=rope - ) - - hidden_states = hidden_states + residual - if self.global_pool == "avg": - hidden_states = hidden_states[:, self.num_prefix_tokens :].mean(dim=1) - - pooled_output = self.selector(hidden_states) - - return BaseModelOutputWithPast( - last_hidden_state=pooled_output, - hidden_states=hidden_states, - ) +# Copyright (c) 2022, Tri Dao. +# This BERT implementation is based on our MLPerf 2.0 and MLPerf 2.1 BERT implementation. +# https://github.com/mlcommons/training_results_v2.0/blob/main/HazyResearch/benchmarks/bert/implementations/pytorch/modeling.py +# https://github.com/mlcommons/training_results_v2.1/blob/main/Azure-HazyResearch/benchmarks/bert/implementations/ND96amsr_A100_v4/modeling.py + +import collections +import logging + +# Inspired by https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py +import math +import os +import re +from collections import OrderedDict +from functools import partial +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange, repeat +from safetensors.torch import load_file as safe_load_file +from torch.nn.modules.utils import _pair +from transformers import GPT2Config, PreTrainedModel, ViTConfig, ViTModel +from transformers.models.bert.modeling_bert import ( + BaseModelOutputWithPoolingAndCrossAttentions, + MaskedLMOutput, + SequenceClassifierOutput, +) +from transformers.modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPast, + BaseModelOutputWithPooling, + MaskedLMOutput, + MultipleChoiceModelOutput, + QuestionAnsweringModelOutput, + SequenceClassifierOutput, + ModelOutput, + TokenClassifierOutput, +) +from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, WEIGHTS_NAME +from transformers.utils.hub import cached_file, get_checkpoint_shard_files + +from .configuration_hf_nomic_bert import NomicBertConfig + +try: + from torch.nn.functional import scaled_dot_product_attention +except ImportError: + scaled_dot_product_attention = None + +logger = logging.getLogger(__name__) + + +# adapted from flash attention, added safe serialization option for hf models +def state_dict_from_pretrained(model_name, safe_serialization=False, device=None, dtype=None): + # If not fp32, then we don't want to load directly to the GPU + mapped_device = "cpu" if dtype not in [torch.float32, None] else device + is_sharded = False + load_safe = False + resolved_archive_file = None + + weights_path = os.path.join(model_name, WEIGHTS_NAME) + weights_index_path = os.path.join(model_name, WEIGHTS_INDEX_NAME) + safe_weights_path = os.path.join(model_name, SAFE_WEIGHTS_NAME) + safe_weights_index_path = os.path.join(model_name, SAFE_WEIGHTS_INDEX_NAME) + + if os.path.isfile(weights_path): + resolved_archive_file = cached_file(model_name, WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False) + elif os.path.isfile(weights_index_path): + resolved_archive_file = cached_file(model_name, WEIGHTS_INDEX_NAME, _raise_exceptions_for_missing_entries=False) + is_sharded = True + elif os.path.isfile(safe_weights_path): + resolved_archive_file = cached_file(model_name, SAFE_WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False) + load_safe = True + elif os.path.isfile(safe_weights_index_path): + resolved_archive_file = cached_file( + model_name, SAFE_WEIGHTS_INDEX_NAME, _raise_exceptions_for_missing_entries=False + ) + is_sharded = True + load_safe = True + else: # Try loading from HF hub instead of from local files + resolved_archive_file = None + for weight_name in [WEIGHTS_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_INDEX_NAME]: + resolved_archive_file = cached_file(model_name, weight_name, _raise_exceptions_for_missing_entries=False) + if resolved_archive_file is not None: + if weight_name in [SAFE_WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME]: + load_safe = True + if weight_name in [WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_INDEX_NAME]: + is_sharded = True + break + + if resolved_archive_file is None: + raise EnvironmentError(f"Model name {model_name} was not found.") + + if load_safe: + loader = partial(safe_load_file, device=mapped_device) + else: + loader = partial(torch.load, map_location=mapped_device) + + if is_sharded: + # resolved_archive_file becomes a list of files that point to the different + # checkpoint shards in this case. + resolved_archive_file, sharded_metadata = get_checkpoint_shard_files(model_name, resolved_archive_file) + state_dict = {} + for sharded_file in resolved_archive_file: + state_dict.update(loader(sharded_file)) + else: + state_dict = loader(resolved_archive_file) + # Convert dtype before moving to GPU to save memory + if dtype is not None: + state_dict = {k: v.to(dtype=dtype) for k, v in state_dict.items()} + state_dict = {k: v.to(device=device) for k, v in state_dict.items()} + return state_dict + + +def filter_shapes(state_dict, model): + """ + Filters the state dict to match the current model shape. + """ + filtered_state_dict = {} + for key, value in state_dict.items(): + if key in model.state_dict(): + if value.shape == model.state_dict()[key].shape: + filtered_state_dict[key] = value + return filtered_state_dict + + +def remap_bert_state_dict( + state_dict, + config, + remove_bert=False, + remove_cls_weights=False, + add_pooling_layer=False, +): + """ + Map the state_dict of a Huggingface BERT model to be flash_attn compatible. + """ + + def add_bert_prefix(key): + # prepend bert. to the key + if key.startswith("bert.") or key.startswith("cls."): + return key + return f"bert.{key}" + + state_dict = OrderedDict((add_bert_prefix(k), v) for k, v in state_dict.items()) + + # LayerNorm + def key_mapping_ln_gamma_beta(key): + key = re.sub(r"LayerNorm.gamma$", "LayerNorm.weight", key) + key = re.sub(r"LayerNorm.beta$", "LayerNorm.bias", key) + return key + + state_dict = OrderedDict((key_mapping_ln_gamma_beta(k), v) for k, v in state_dict.items()) + + # Layers + def key_mapping_layers(key): + return re.sub(r"^bert.encoder.layer\.", "bert.encoder.layers.", key) + + state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items()) + + # LayerNorm + def key_mapping_ln(key): + key = re.sub(r"^bert.embeddings.LayerNorm.", "bert.emb_ln.", key) + key = re.sub( + r"^bert.encoder.layers.(\d+).attention.output.LayerNorm.(weight|bias)", + r"bert.encoder.layers.\1.norm1.\2", + key, + ) + key = re.sub( + r"^bert.encoder.layers.(\d+).output.LayerNorm.(weight|bias)", + r"bert.encoder.layers.\1.norm2.\2", + key, + ) + key = re.sub( + r"^cls.predictions.transform.LayerNorm.(weight|bias)", + r"cls.predictions.transform.layer_norm.\1", + key, + ) + return key + + state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items()) + + # MLP + def key_mapping_mlp(key): + key = re.sub( + r"^bert.encoder.layers.(\d+).intermediate.dense.(weight|bias)", + r"bert.encoder.layers.\1.mlp.fc1.\2", + key, + ) + key = re.sub( + r"^bert.encoder.layers.(\d+).output.dense.(weight|bias)", + r"bert.encoder.layers.\1.mlp.fc2.\2", + key, + ) + return key + + state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items()) + + # Attention + last_layer_subset = getattr(config, "last_layer_subset", False) + for d in range(config.num_hidden_layers): + if f"bert.encoder.layers.{d}.attention.self.query.weight" not in state_dict: + continue + Wq = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.query.weight") + Wk = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.key.weight") + Wv = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.value.weight") + bq = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.query.bias") + bk = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.key.bias") + bv = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.value.bias") + if not (last_layer_subset and d == config.num_hidden_layers - 1): + state_dict[f"bert.encoder.layers.{d}.attn.Wqkv.weight"] = torch.cat([Wq, Wk, Wv], dim=0) + state_dict[f"bert.encoder.layers.{d}.attn.Wqkv.bias"] = torch.cat([bq, bk, bv], dim=0) + else: + state_dict[f"bert.encoder.layers.{d}.attn.Wq.weight"] = Wq + state_dict[f"bert.encoder.layers.{d}.attn.Wkv.weight"] = torch.cat([Wk, Wv], dim=0) + state_dict[f"bert.encoder.layers.{d}.attn.Wq.bias"] = bq + state_dict[f"bert.encoder.layers.{d}.attn.Wkv.bias"] = torch.cat([bk, bv], dim=0) + + def key_mapping_attn(key): + return re.sub( + r"^bert.encoder.layers.(\d+).attention.output.dense.(weight|bias)", + r"bert.encoder.layers.\1.attn.out_proj.\2", + key, + ) + + state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items()) + + def key_mapping_decoder_bias(key): + return re.sub(r"^cls.predictions.bias", "cls.predictions.decoder.bias", key) + + # remove nsp weights, we don't use + state_dict.pop("cls.seq_relationship.weight", None) + state_dict.pop("cls.seq_relationship.bias", None) + state_dict.pop("bert.embeddings.position_ids", None) + + state_dict = OrderedDict((key_mapping_decoder_bias(k), v) for k, v in state_dict.items()) + + if remove_cls_weights: + cls_weights = [ + "cls.predictions.decoder.bias", + "cls.predictions.transform.dense.weight", + "cls.predictions.transform.dense.bias", + "cls.predictions.transform.layer_norm.weight", + "cls.predictions.transform.layer_norm.bias", + "cls.predictions.decoder.weight", + ] + for weight in cls_weights: + state_dict.pop(weight, None) + + # Word embedding + pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1) + if pad_vocab_size_multiple > 1: + word_embeddings = state_dict["bert.embeddings.word_embeddings.weight"] + state_dict["bert.embeddings.word_embeddings.weight"] = F.pad( + word_embeddings, (0, 0, 0, config.vocab_size - word_embeddings.shape[0]) + ) + if not remove_cls_weights: + if "cls.predictions.decoder.weight" not in state_dict: + state_dict['cls.predictions.decoder.weight'] = state_dict['bert.embeddings.word_embeddings.weight'].clone() + else: + decoder_weight = state_dict["cls.predictions.decoder.weight"] + state_dict["cls.predictions.decoder.weight"] = F.pad( + decoder_weight, (0, 0, 0, config.vocab_size - decoder_weight.shape[0]) + ) + # If the vocab was padded, we want to set the decoder bias for those padded indices to be + # strongly negative (i.e. the decoder shouldn't predict those indices). + # TD [2022-05-09]: I don't think it affects the MLPerf training. + if "cls.predictions.decoder.bias" in state_dict: + decoder_bias = state_dict["cls.predictions.decoder.bias"] + state_dict["cls.predictions.decoder.bias"] = F.pad( + decoder_bias, (0, config.vocab_size - decoder_bias.shape[0]), value=-100.0 + ) + + if add_pooling_layer is False: + pooler_weights = [ + "bert.pooler.dense.weight", + "bert.pooler.dense.bias", + ] + for key in pooler_weights: + state_dict.pop(key, None) + + if remove_bert: + + def remove_bert_prefix(key): + key = re.sub(r"^bert.", "", key) + return key + + state_dict = OrderedDict((remove_bert_prefix(k), v) for k, v in state_dict.items()) + + return state_dict + + +def _trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 + + if (mean < a - 2 * std) or (mean > b + 2 * std): + print( + "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2, + ) + + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.0)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def trunc_normal_tf_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0): + r"""Fills the input Tensor with values drawn from a truncated + normal distribution. The values are effectively drawn from the + normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \leq \text{mean} \leq b`. + + NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the + bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0 + and the result is subsquently scaled and shifted by the mean and std args. + + Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + Examples: + >>> w = torch.empty(3, 5) + >>> nn.init.trunc_normal_(w) + """ + with torch.no_grad(): + _trunc_normal_(tensor, 0, 1.0, a, b) + tensor.mul_(std).add_(mean) + return tensor + + +class NomicBertPreTrainedModel(PreTrainedModel): + """An abstract class to handle weights initialization and + a simple interface for dowloading and loading pretrained models. + """ + + config_class = NomicBertConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["Block"] + _skip_keys_device_placement = "past_key_values" + + def __init__(self, config, *inputs, **kwargs): + super().__init__(config) + if not isinstance(config, GPT2Config): + raise ValueError( + "Parameter config in `{}(config)` should be an instance of class `GPT2Config`. " + "To create a model from a Google pretrained model use " + "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format( + self.__class__.__name__, self.__class__.__name__ + ) + ) + self.config = config + + @classmethod + def from_pretrained(cls, model_name, config=None, *inputs, **kwargs): + """ + Instantiate a NomicBertPreTrainedModel from a pre-trained model file or a pytorch state dict. + Download and cache the pre-trained model file if needed. + + Params: + pretrained_model_name_or_path: either: + - a path or url to a pretrained model archive containing: + . `bert_config.json` a configuration file for the model + . `pytorch_model.bin` a PyTorch dump of a NomicBertForPretraining instance + - a path or url to a pretrained model archive containing: + . `bert_config.json` a configuration file for the model + . `model.chkpt` a TensorFlow checkpoint + *inputs, **kwargs: additional input for the specific NomicBert class + (ex: num_labels for NomicBertForSequenceClassification) + """ + # Instantiate model. + if config is None: + config = cls.config_class.from_pretrained(model_name) + remove_cls = cls != NomicBertForPreTraining + remove_bert_prefix = cls not in [NomicBertForPreTraining, NomicBertForSequenceClassification, NomicBertForTokenClassification, NomicBertForMultipleChoice, NomicBertForQuestionAnswering] + ignore_mismatched_shapes = kwargs.pop("ignore_mismatched_sizes", False) + num_labels = kwargs.pop("num_labels", None) + rotary_scaling_factor = kwargs.pop("rotary_scaling_factor", None) + strict = kwargs.pop("strict", True) + dtype = kwargs.pop("torch_dtype", None) + if rotary_scaling_factor: + config.rotary_scaling_factor = rotary_scaling_factor + + if config.n_positions <= 0 and config.rotary_emb_fraction > 0: + config.n_positions = 2048 + if num_labels: + config.num_labels = num_labels + + if "add_pooling_layer" in kwargs: + model = cls(config, *inputs, add_pooling_layer=kwargs.pop("add_pooling_layer")) + else: + if cls == NomicBertModel: + model = cls(config, *inputs, add_pooling_layer=False) + else: + model = cls(config, *inputs) + + if dtype is not None: + model = model.to(dtype=dtype) + # TODO: fix this + # Assuming we know what we're doing when loading from disk + # Prob a bad assumption but i'm tired and want to train this asap + if os.path.exists(model_name): + model_path = f"{model_name}/pytorch_model.bin" + if os.path.exists(model_path): + state_dict = torch.load(f"{model_name}/pytorch_model.bin") + else: + model_path = f"{model_name}/model.safetensors" + if not os.path.exists(model_path): + raise ValueError(f"Model path {model_path} not found") + state_dict = safe_load_file(model_path) + + if ignore_mismatched_shapes: + state_dict = filter_shapes(state_dict, model) + load_return = model.load_state_dict(state_dict, strict=False) + else: + # TODO: can probably check config class and see if we need to remap from a bert model + state_dict = state_dict_from_pretrained(model_name, dtype=dtype) + state_dict = remap_bert_state_dict( + state_dict, + config, + remove_bert=remove_bert_prefix, + remove_cls_weights=remove_cls, + add_pooling_layer=getattr(config, "add_pooling_layer", False), + ) + if ignore_mismatched_shapes: + state_dict = filter_shapes(state_dict, model) + + load_return = model.load_state_dict(state_dict, strict=strict) + logger.warning(load_return) + return model + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, NomicBertEncoder): + module.gradient_checkpointing = value + + +# https://github.com/huggingface/transformers/blob/7032e0203262ebb2ebf55da8d2e01f873973e835/src/transformers/models/bert/modeling_bert.py#L748 +def _init_weights(module, initializer_range=0.02): + if isinstance(module, nn.Linear): + nn.init.normal_(module.weight, std=initializer_range) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + nn.init.normal_(module.weight, std=initializer_range) + if module.padding_idx is not None: + nn.init.zeros_(module.weight[module.padding_idx]) + + +def _ntuple(n): + def parse(x): + if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): + return tuple(x) + return tuple(repeat(x, n)) + + return parse + + +to_1tuple = _ntuple(1) +to_2tuple = _ntuple(2) +to_3tuple = _ntuple(3) +to_4tuple = _ntuple(4) +to_ntuple = _ntuple + + +def get_2d_sincos_pos_embed(embed_dim, grid_size, add_cls_token=False): + """ + Create 2D sin/cos positional embeddings. + + Args: + embed_dim (`int`): + Embedding dimension. + grid_size (`int`): + The grid height and width. + add_cls_token (`bool`, *optional*, defaults to `False`): + Whether or not to add a classification (CLS) token. + + Returns: + (`torch.FloatTensor` of shape (grid_size*grid_size, embed_dim) or (1+grid_size*grid_size, embed_dim): the + position embeddings (with or without classification token) + """ + grid_h = np.arange(grid_size, dtype=np.float32) + + grid_w = np.arange(grid_size, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, grid_size, grid_size]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if add_cls_token: + pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + if embed_dim % 2 != 0: + raise ValueError("embed_dim must be even") + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D) + """ + if embed_dim % 2 != 0: + raise ValueError("embed_dim must be even") + + omega = np.arange(embed_dim // 2, dtype=float) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +def ndgrid(*tensors) -> Tuple[torch.Tensor, ...]: + """generate N-D grid in dimension order. + + The ndgrid function is like meshgrid except that the order of the first two input arguments are switched. + + That is, the statement + [X1,X2,X3] = ndgrid(x1,x2,x3) + + produces the same result as + + [X2,X1,X3] = meshgrid(x2,x1,x3) + + This naming is based on MATLAB, the purpose is to avoid confusion due to torch's change to make + torch.meshgrid behaviour move from matching ndgrid ('ij') indexing to numpy meshgrid defaults of ('xy'). + + """ + try: + return torch.meshgrid(*tensors, indexing='ij') + except TypeError: + # old PyTorch < 1.10 will follow this path as it does not have indexing arg, + # the old behaviour of meshgrid was 'ij' + return torch.meshgrid(*tensors) + + +def build_fourier_pos_embed( + feat_shape: List[int], + bands: Optional[torch.Tensor] = None, + num_bands: int = 64, + max_res: int = 224, + temperature: float = 10000.0, + linear_bands: bool = False, + include_grid: bool = False, + in_pixels: bool = True, + ref_feat_shape: Optional[List[int]] = None, + dtype: torch.dtype = torch.float32, + device: Optional[torch.device] = None, +) -> List[torch.Tensor]: + """ + + Args: + feat_shape: Feature shape for embedding. + bands: Pre-calculated frequency bands. + num_bands: Number of frequency bands (determines output dim). + max_res: Maximum resolution for pixel based freq. + temperature: Temperature for non-pixel freq. + linear_bands: Linear band spacing for pixel based freq. + include_grid: Include the spatial grid in output. + in_pixels: Output in pixel freq. + ref_feat_shape: Reference feature shape for resize / fine-tune. + dtype: Output dtype. + device: Output device. + + Returns: + + """ + if bands is None: + if in_pixels: + bands = pixel_freq_bands( + num_bands, + float(max_res), + linear_bands=linear_bands, + device=device, + ) + else: + bands = freq_bands( + num_bands, + temperature=temperature, + step=1, + device=device, + ) + else: + if device is None: + device = bands.device + if dtype is None: + dtype = bands.dtype + + if in_pixels: + t = [torch.linspace(-1.0, 1.0, steps=s, device=device, dtype=torch.float32) for s in feat_shape] + else: + t = [torch.arange(s, device=device, dtype=torch.int64).to(torch.float32) for s in feat_shape] + + if ref_feat_shape is not None: + # eva's scheme for resizing rope embeddings (ref shape = pretrain) + t = [x / f * r for x, f, r in zip(t, feat_shape, ref_feat_shape)] + + grid = torch.stack(ndgrid(t), dim=-1) + grid = grid.unsqueeze(-1) + pos = grid * bands + + pos_sin, pos_cos = pos.sin().to(dtype=dtype), pos.cos().to(dtype) + out = [grid, pos_sin, pos_cos] if include_grid else [pos_sin, pos_cos] + return out + + +def build_rotary_pos_embed( + feat_shape: List[int], + bands: Optional[torch.Tensor] = None, + dim: int = 64, + max_res: int = 224, + temperature: float = 10000.0, + linear_bands: bool = False, + in_pixels: bool = True, + ref_feat_shape: Optional[List[int]] = None, + dtype: torch.dtype = torch.float32, + device: Optional[torch.device] = None, +): + """ + + Args: + feat_shape: Spatial shape of the target tensor for embedding. + bands: Optional pre-generated frequency bands + dim: Output dimension of embedding tensor. + max_res: Maximum resolution for pixel mode. + temperature: Temperature (inv freq) for non-pixel mode + linear_bands: Linearly (instead of log) spaced bands for pixel mode + in_pixels: Pixel vs language (inv freq) mode. + dtype: Output dtype. + device: Output device. + + Returns: + + """ + sin_emb, cos_emb = build_fourier_pos_embed( + feat_shape, + bands=bands, + num_bands=dim // 4, + max_res=max_res, + temperature=temperature, + linear_bands=linear_bands, + in_pixels=in_pixels, + ref_feat_shape=ref_feat_shape, + device=device, + dtype=dtype, + ) + num_spatial_dim = 1 + # this would be much nicer as a .numel() call to torch.Size(), but torchscript sucks + for x in feat_shape: + num_spatial_dim *= x + sin_emb = sin_emb.reshape(num_spatial_dim, -1).repeat_interleave(2, -1) + cos_emb = cos_emb.reshape(num_spatial_dim, -1).repeat_interleave(2, -1) + return sin_emb, cos_emb + + +def freq_bands( + num_bands: int, + temperature: float = 10000.0, + step: int = 2, + device: Optional[torch.device] = None, +) -> torch.Tensor: + exp = torch.arange(0, num_bands, step, dtype=torch.int64, device=device).to(torch.float32) / num_bands + bands = 1.0 / (temperature**exp) + return bands + + +def pixel_freq_bands( + num_bands: int, + max_freq: float = 224.0, + linear_bands: bool = True, + device: Optional[torch.device] = None, +): + if linear_bands: + bands = torch.linspace(1.0, max_freq / 2, num_bands, dtype=torch.float32, device=device) + else: + bands = 2 ** torch.linspace(0, math.log(max_freq, 2) - 1, num_bands, dtype=torch.float32, device=device) + return bands * torch.pi + + +def rot(x): + return torch.stack([-x[..., 1::2], x[..., ::2]], -1).reshape(x.shape) + + +def apply_rot_embed_cat(x: torch.Tensor, emb): + sin_emb, cos_emb = emb.tensor_split(2, -1) + if sin_emb.ndim == 3: + return x * cos_emb.unsqueeze(1).expand_as(x) + rot(x) * sin_emb.unsqueeze(1).expand_as(x) + return x * cos_emb + rot(x) * sin_emb + + +# taken from https://github.com/huggingface/pytorch-image-models/blob/cb0e4391beedcc5ac3ae4bce16561b95c326f32c/timm/layers/pos_embed_sincos.py#L363 +class NomicVisionRotaryEmbeddingCat(nn.Module): + """Rotary position embedding w/ concatenatd sin & cos + + The following impl/resources were referenced for this impl: + * https://github.com/lucidrains/vit-pytorch/blob/6f3a5fcf0bca1c5ec33a35ef48d97213709df4ba/vit_pytorch/rvt.py + * https://blog.eleuther.ai/rotary-embeddings/ + """ + + def __init__( + self, + dim, + max_res=224, + temperature=10000, + in_pixels=True, + linear_bands: bool = False, + feat_shape: Optional[List[int]] = None, + ref_feat_shape: Optional[List[int]] = None, + ): + super().__init__() + self.dim = dim + self.max_res = max_res + self.temperature = temperature + self.in_pixels = in_pixels + self.feat_shape = feat_shape + self.ref_feat_shape = ref_feat_shape + + if feat_shape is None: + # only cache bands + if in_pixels: + bands = pixel_freq_bands( + dim // 4, + float(max_res), + linear_bands=linear_bands, + ) + else: + bands = freq_bands( + dim // 4, + temperature=temperature, + step=1, + ) + self.register_buffer( + 'bands', + bands, + persistent=False, + ) + self.pos_embed = None + else: + # cache full sin/cos embeddings if shape provided up front + embeds = build_rotary_pos_embed( + feat_shape=feat_shape, + dim=dim, + max_res=max_res, + linear_bands=linear_bands, + in_pixels=in_pixels, + ref_feat_shape=self.ref_feat_shape, + ) + self.bands = None + self.register_buffer( + 'pos_embed', + torch.cat(embeds, -1), + persistent=False, + ) + + def get_embed(self, shape: Optional[List[int]] = None): + if self.bands is not None and shape is not None: + # rebuild embeddings every call, use if target shape changes + embeds = build_rotary_pos_embed( + shape, + self.bands, + in_pixels=self.in_pixels, + ref_feat_shape=self.ref_feat_shape, + ) + return torch.cat(embeds, -1) + elif self.pos_embed is not None: + return self.pos_embed + else: + assert False, "get_embed() requires pre-computed pos_embed or valid shape w/ pre-computed bands" + + def forward(self, x): + # assuming channel-first tensor where spatial dim are >= 2 + pos_embed = self.get_embed(x.shape[2:]) + return apply_rot_embed_cat(x, pos_embed) + + +class NomicVisionPatchEmbeddings(nn.Module): + def __init__( + self, + config, + ): + super().__init__() + img_size = _pair(config.img_size) + patch_size = _pair(config.patch_size) + self.img_size = img_size + self.patch_size = patch_size + self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) + self.num_patches = self.grid_size[0] * self.grid_size[1] + + self.proj = nn.Linear( + config.num_channels * patch_size[0] * patch_size[1], config.n_embd, bias=config.patch_embed_bias + ) + + self.learned_pos_embedding = False + self.sinusoidal_pos_embedding = False + self.no_embed_class = getattr(config, "no_embed_class", False) + + self.cls_token = ( + nn.Parameter(torch.zeros(1, 1, config.n_embd)) if not getattr(config, "no_cls_token", False) else None + ) + if config.learned_pos_embedding: + # this is the default in DINO + self.learned_pos_embedding = True + # hack for timm dinov2 with registers + num_patches = self.num_patches if getattr(config, "register_tokens", 0) > 0 else self.num_patches + 1 + self.pos_embed = ( + nn.Parameter(torch.randn(1, num_patches, config.n_embd) * 0.02) + if getattr(config, "use_pos_embed", True) + else None + ) + elif getattr(config, "sinusoidal_pos_embedding", False): + self.sinusoidal_pos_embedding = True + if getattr(config, "use_pos_embed", True): + self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, config.n_embd), requires_grad=False) + pos_embed = get_2d_sincos_pos_embed(config.n_embd, self.grid_size[0], add_cls_token=True) + self.pos_embed.data.copy_(torch.from_numpy(pos_embed).to(self.pos_embed)) + else: + self.pos_embed = None + else: + self.pos_embed = ( + nn.Parameter(torch.randn(1, self.num_patches + 1, config.n_embd) * 0.02) + if getattr(config, "use_pos_embed", True) + else None + ) + + if getattr(config, "register_tokens", 0) > 0: + self.reg_token = nn.Parameter(torch.randn(1, config.register_tokens, config.n_embd) * 0.02) + else: + self.reg_token = None + + if config.mask_token: + self.mask_token = nn.Parameter(torch.zeros(1, config.n_embd)) + + self.patch_dropout = nn.Identity() + + if getattr(config, "use_rotary_pos_emb", False): + ref_feat_shape = getattr(config, "ref_feat_shape", None) + ref_feat_shape = to_2tuple(ref_feat_shape) if ref_feat_shape is not None else None + self.rope = NomicVisionRotaryEmbeddingCat( + config.n_embd // config.n_head, + in_pixels=False, + feat_shape=self.grid_size, + ref_feat_shape=ref_feat_shape, + ) + else: + self.rope = None + + def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher + resolution images. + + Source: + https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174 + """ + num_patches = embeddings.shape[1] - 1 + num_positions = self.pos_embed.shape[1] - 1 + if num_patches == num_positions and height == width: + return self.pos_embed + class_pos_embed = self.pos_embed[:, 0] + patch_pos_embed = self.pos_embed[:, 1:] + dim = embeddings.shape[-1] + height = height // self.patch_size[0] + width = width // self.patch_size[1] + # we add a small number to avoid floating point error in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + height, width = height + 0.1, width + 0.1 + patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + scale_factor=(height / math.sqrt(num_positions), width / math.sqrt(num_positions)), + mode="bicubic", + align_corners=False, + ) + if int(height) != patch_pos_embed.shape[-2] or int(width) != patch_pos_embed.shape[-1]: + raise ValueError("Width or height does not match with the interpolated position embeddings") + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) + + def forward(self, x): + # deepspeed case where the input is in fp32 + if x.dtype != self.proj.weight.dtype: + x = x.to(dtype=self.proj.weight.dtype) + + _, _, height, width = x.shape + x = self.proj( + rearrange( + x, + "b c (h p1) (w p2) -> b h w (c p1 p2)", + p1=self.patch_size[0], + p2=self.patch_size[1], + ) + ) + embeddings = rearrange(x, "b h w c -> b (h w) c") + + to_cat = [] + if self.cls_token is not None: + if self.sinusoidal_pos_embedding: + cls_token = self.cls_token + self.pos_embed[:, 0] + cls_token = cls_token.expand(embeddings.shape[0], -1, -1) + to_cat += [cls_token] + else: + cls_token = self.cls_token.expand(embeddings.shape[0], 1, -1) + to_cat += [cls_token] + + if self.reg_token is not None: + to_cat += [self.reg_token.expand(embeddings.shape[0], -1, -1)] + + rot_pos_embed = self.rope.get_embed() if self.rope is not None else None + + if self.no_embed_class: + if self.learned_pos_embedding: + embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) + else: + if self.pos_embed is not None: + embeddings = embeddings + self.pos_embed + if to_cat: + embeddings = torch.cat(to_cat + [embeddings], dim=1) + else: + if to_cat: + embeddings = torch.cat(to_cat + [embeddings], dim=1) + if self.learned_pos_embedding: + if self.pos_embed is not None: + embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) + else: + if self.pos_embed is not None: + embeddings = embeddings + self.pos_embed + + embeddings = self.patch_dropout(embeddings) + + return embeddings, rot_pos_embed + + +class NomicBertEmbeddings(nn.Module): + def __init__(self, config): + """ + If max_position_embeddings <= 0, there's no position embeddings + If type_vocab_size <= 0, there's no token type embeddings + """ + super().__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + self.max_position_embeddings = config.max_position_embeddings if config.rotary_emb_fraction <= 0 else 0 + self.type_vocab_size = config.type_vocab_size + if self.max_position_embeddings > 0 and config.rotary_emb_fraction <= 0: + self.position_embeddings = nn.Embedding( + config.max_position_embeddings, + config.hidden_size, + ) + if self.type_vocab_size > 0: + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) + + def forward(self, input_ids=None, position_ids=None, token_type_ids=None, inputs_embeds=None): + """ + input_ids: (batch, seqlen) + position_ids: (batch, seqlen) + token_type_ids: (batch, seqlen) + """ + if inputs_embeds is None: + embeddings = self.word_embeddings(input_ids) + else: + embeddings = inputs_embeds + batch_size, seqlen, _ = embeddings.shape + + if self.type_vocab_size > 0: + if token_type_ids is None: + token_type_ids = torch.zeros(seqlen, dtype=torch.long, device=embeddings.device) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + embeddings = embeddings + token_type_embeddings + + if self.max_position_embeddings > 0: + if position_ids is None: + position_ids = torch.arange(seqlen, dtype=torch.long, device=embeddings.device) + position_embeddings = self.position_embeddings(position_ids) + embeddings = embeddings + position_embeddings + return embeddings + + +class NomicBertMLP(nn.Module): + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + activation=F.gelu, + bias1=True, + bias2=True, + return_residual=False, + fused_bias_fc=False, + ): + super().__init__() + out_features = out_features if out_features is not None else in_features + hidden_features = hidden_features if hidden_features is not None else in_features * 4 + self.return_residual = return_residual + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1) + approximate = "tanh" if activation in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"] else "none" + self.activation = nn.GELU(approximate=approximate) if activation == "gelu" else activation + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2) + + def forward(self, x): + y = self.fc1(x) + y = self.activation(y) + y = self.fc2(y) + return y if not self.return_residual else (y, x) + + +class NomciBertGatedMLP(nn.Module): + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + activation=F.sigmoid, + bias1=True, + bias2=True, + multiple_of=256, + return_residual=False, + fused_bias_fc=True, + device=None, + dtype=None, + norm_layer=False, + ): + super().__init__() + out_features = out_features if out_features is not None else in_features + hidden_features = hidden_features if hidden_features is not None else int(8 * in_features / 3) + hidden_features = int((hidden_features + multiple_of - 1) // multiple_of * multiple_of) + self.return_residual = return_residual + + self.fc11 = nn.Linear(in_features, hidden_features, bias=bias1) + self.fc12 = nn.Linear(in_features, hidden_features, bias=bias1) + self.activation = activation + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2) + self.norm = nn.LayerNorm(hidden_features) if norm_layer else nn.Identity() + + def forward(self, x): + y = self.fc11(x) + gate = self.fc12(x) + if self.activation == F.sigmoid: # Special case for GLU + y = F.glu(torch.cat([y, gate], dim=-1), dim=-1) + else: + y = y * self.activation(gate) + + # eva uses layer norm after the activation + y = self.norm(y) + + y = self.fc2(y) + return y if not self.return_residual else (y, x) + + +def rotate_half(x, interleaved=False): + if not interleaved: + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + else: + x1, x2 = x[..., ::2], x[..., 1::2] + return rearrange(torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2) + + +def apply_rotary_emb(x, cos, sin, offset=0, interleaved=False): + """ + x: (batch_size, seqlen, nheads, headdim) + cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2) + """ + ro_dim = cos.shape[-1] * 2 + assert ro_dim <= x.shape[-1] + cos, sin = ( + cos[offset : offset + x.shape[1]], + sin[offset : offset + x.shape[1]], + ) + cos = repeat(cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)") + sin = repeat(sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)") + return torch.cat( + [x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:]], + dim=-1, + ) + + +class NomicBertRotaryEmbedding(nn.Module): + def __init__( + self, + dim: int, + base=10000.0, + interleaved=False, + scale_base=None, + pos_idx_in_fp32=True, + device=None, + ): + """ + interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead + of 1st half and 2nd half (GPT-NeoX style). + pos_idx_in_fp32: if True, the position indices [0.0, ..., seqlen - 1] are in fp32, + otherwise they might be in lower precision. + This option was added because previously (before 2023-07-02), when we construct + the position indices, we use the dtype of self.inv_freq. In most cases this would + be fp32, but if the model is trained in pure bf16 (not mixed precision), then + self.inv_freq would be bf16, and the position indices are also in bf16. + Because of the limited precision of bf16 (e.g. 1995.0 is rounded to 2000.0), the + embeddings for some positions will coincide. + To maintain compatibility with models previously trained in pure bf16, + we add this option. + """ + super().__init__() + self.dim = dim + self.base = float(base) + self.pos_idx_in_fp32 = pos_idx_in_fp32 + # Generate and save the inverse frequency buffer (non trainable) + inv_freq = self._compute_inv_freq(device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.interleaved = interleaved + self.scale_base = scale_base + scale = ( + (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim) + if scale_base is not None + else None + ) + self.register_buffer("scale", scale, persistent=False) + + self._seq_len_cached = 0 + self._cos_cached = None + self._sin_cached = None + self._cos_k_cached = None + self._sin_k_cached = None + + def _compute_inv_freq(self, device=None): + return 1.0 / (self.base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim)) + + def _update_cos_sin_cache(self, seqlen, device=None, dtype=None): + # Reset the tables if the sequence length has changed, + # if we're on a new device (possibly due to tracing for instance), + # or if we're switching from inference mode to training + if ( + seqlen > self._seq_len_cached + or self._cos_cached is None + or self._cos_cached.device != device + or self._cos_cached.dtype != dtype + or (self.training and self._cos_cached.is_inference()) + ): + self._seq_len_cached = seqlen + # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16 + # And the output of arange can be quite large, so bf16 would lose a lot of precision. + # However, for compatibility reason, we add an option to use the dtype of self.inv_freq. + if self.pos_idx_in_fp32: + t = torch.arange(seqlen, device=device, dtype=torch.float32) + # We want fp32 here as well since inv_freq will be multiplied with t, and the output + # will be large. Having it in bf16 will lose a lot of precision and cause the + # cos & sin output to change significantly. + # We want to recompute self.inv_freq if it was not loaded in fp32 + if self.inv_freq.dtype != torch.float32: + inv_freq = self._compute_inv_freq(device=device) + else: + inv_freq = self.inv_freq + else: + t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype) + inv_freq = self.inv_freq + # Don't do einsum, it converts fp32 to fp16 under AMP + # freqs = torch.einsum("i,j->ij", t, self.inv_freq) + freqs = torch.outer(t, inv_freq) + self._cos_cached = torch.cos(freqs).to(dtype) + self._sin_cached = torch.sin(freqs).to(dtype) + + def forward( + self, + qkv: torch.Tensor, + kv: Optional[torch.Tensor] = None, + seqlen_offset: Union[int, torch.Tensor] = 0, + max_seqlen: Optional[int] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + qkv: (batch, seqlen, 3, nheads, headdim) if kv is none, + else it's just q of shape (batch, seqlen, nheads, headdim) + kv: (batch, seqlen, 2, nheads, headdim) + seqlen_offset: (batch_size,) or int. Each sequence in x is shifted by this amount. + Most commonly used in inference when we have KV cache. + If it's a tensor of shape (batch_size,), then to update the cos / sin cache, one + should pass in max_seqlen, which will update the cos / sin cache up to that length. + Apply rotary embedding *inplace* to qkv and / or kv. + """ + seqlen = qkv.shape[1] + if seqlen > self._seq_len_cached: + self._update_cos_sin_cache(seqlen, device=qkv.device, dtype=qkv.dtype) + elif max_seqlen is not None: + self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype) + elif isinstance(seqlen_offset, int): + self._update_cos_sin_cache(seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype) + + q_rot = apply_rotary_emb(qkv[:, :, 0], self._cos_cached, self._sin_cached, seqlen_offset, self.interleaved) + k_rot = apply_rotary_emb(qkv[:, :, 1], self._cos_cached, self._sin_cached, seqlen_offset, self.interleaved) + return torch.stack((q_rot, k_rot, qkv[:, :, 2]), dim=2) + + +class NomicBertDynamicNTKRotaryEmbedding(NomicBertRotaryEmbedding): + def __init__(self, rotary_scaling_factor, max_position_embeddings, **kwargs): + super().__init__(**kwargs) + self.rotary_scaling_factor = rotary_scaling_factor + self.max_position_embeddings = max_position_embeddings + + def _compute_inv_freq(self, base=None, device=None): + if base is None: + base = self.base + return 1.0 / (base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim)) + + def _update_cos_sin_cache(self, seqlen, device=None, dtype=None): + # Reset the tables if the sequence length has changed, + # if we're on a new device (possibly due to tracing for instance), + # or if we're switching from inference mode to training + if seqlen > self.max_position_embeddings: + base = self.base * ( + (self.rotary_scaling_factor * seqlen / self.max_position_embeddings) - (self.rotary_scaling_factor - 1) + ) ** (self.dim / (self.dim - 2)) + inv_freq = self._compute_inv_freq(base=base, device=device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + if ( + seqlen > self._seq_len_cached + or self._cos_cached is None + or self._cos_cached.device != device + or self._cos_cached.dtype != dtype + or (self.training and self._cos_cached.is_inference()) + ): + self._seq_len_cached = seqlen + # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16 + # And the output of arange can be quite large, so bf16 would lose a lot of precision. + # However, for compatibility reason, we add an option to use the dtype of self.inv_freq. + if self.pos_idx_in_fp32: + t = torch.arange(seqlen, device=device, dtype=torch.float32) + # We want fp32 here as well since inv_freq will be multiplied with t, and the output + # will be large. Having it in bf16 will lose a lot of precision and cause the + # cos & sin output to change significantly. + # We want to recompute self.inv_freq if it was not loaded in fp32 + if self.inv_freq.dtype != torch.float32: + if seqlen > self.max_position_embeddings: + base = self.base * ( + (self.scaling_factor * seqlen / self.max_position_embeddings) - (self.scaling_factor - 1) + ) ** (self.dim / (self.dim - 2)) + else: + base = self.base + inv_freq = self._compute_inv_freq(device=device, base=base) + else: + inv_freq = self.inv_freq + else: + t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype) + inv_freq = self.inv_freq + # Don't do einsum, it converts fp32 to fp16 under AMP + # freqs = torch.einsum("i,j->ij", t, self.inv_freq) + freqs = torch.outer(t, inv_freq) + if self.scale is None: + self._cos_cached = torch.cos(freqs).to(dtype) + self._sin_cached = torch.sin(freqs).to(dtype) + else: + power = ( + torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) - seqlen // 2 + ) / self.scale_base + scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1") + # We want the multiplication by scale to happen in fp32 + self._cos_cached = (torch.cos(freqs) * scale).to(dtype) + self._sin_cached = (torch.sin(freqs) * scale).to(dtype) + self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype) + self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype) + + +class NomicBertAttention(nn.Module): + """Multi-head self-attention and cross-attention""" + + def __init__( + self, + config, + ) -> None: + """ + num_heads_kv: can be used to toggle MQA / GQA. If None, use num_heads. + return_residual: whether to return the input x along with the output. This is for + performance reason: for post-norm architecture, returning the input allows us + to fuse the backward of nn.Linear with the residual connection. + """ + super().__init__() + self.embed_dim = config.n_embd + self.use_flash_attn = config.use_flash_attn + self.fused_bias_fc = config.fused_bias_fc + + self.num_heads = config.n_head + self.num_heads_kv = config.num_heads_kv if getattr(config, "num_heads_kv", None) is not None else self.num_heads + assert self.embed_dim % self.num_heads == 0, "embed_dim must be divisible by num_heads" + self.head_dim = self.embed_dim // self.num_heads + # we don't really support mqa / gqa for now + qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv) + + self.register_buffer( + "norm_factor", + torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32)).to(torch.get_default_dtype()), + persistent=False, + ) + + self.rotary_emb_dim = self.head_dim * config.rotary_emb_fraction + if self.rotary_emb_dim > 0: + if getattr(config, "rotary_scaling_factor", None): + self.rotary_emb = NomicBertDynamicNTKRotaryEmbedding( + dim=self.rotary_emb_dim, + base=config.rotary_emb_base, + scale_base=config.rotary_emb_scale_base, + interleaved=config.rotary_emb_interleaved, + rotary_scaling_factor=config.rotary_scaling_factor, + max_position_embeddings=config.max_trained_positions, + ) + else: + self.rotary_emb = NomicBertRotaryEmbedding( + dim=self.rotary_emb_dim, + base=config.rotary_emb_base, + scale_base=config.rotary_emb_scale_base, + interleaved=config.rotary_emb_interleaved, + ) + # bug in xformers: https://github.com/facebookresearch/xformers/issues/841 + # uses the head dimension instead of the sequence dimension + self.rotary_head_dim = getattr(config, "rotary_head_dim", False) + + self.Wqkv = nn.Linear(self.embed_dim, qkv_dim, bias=config.qkv_proj_bias) + + self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_proj_bias) + self.causal = config.causal + self.drop = nn.Dropout(config.attn_pdrop) + self.num_prefix_tokens = max(getattr(config, "register_tokens", 1), 1) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + is_padded_inputs: Optional[bool] = True, + cu_seqlens: Optional[torch.Tensor] = None, + max_seq_len: Optional[int] = None, + rope: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + + has_layer_past = past_key_value is not None + + if has_layer_past: + past_key_value = past_key_value[0] + past_len = past_key_value[1] + else: + past_len = 0 + + qkv = self.Wqkv(hidden_states) + qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim) + + past_key_value = (past_key_value, past_len + qkv.size(1)) if use_cache else None + + if self.rotary_emb_dim > 0: + if self.rotary_head_dim: + qkv = rearrange(qkv, "b s three h d -> b h three s d") + qkv = self.rotary_emb(qkv, seqlen_offset=past_len) + + if self.rotary_head_dim: + qkv = rearrange(qkv, "b h three s d -> b s three h d") + elif rope is not None: + q, k, v = qkv.permute(0, 3, 1, 2, 4).unbind(dim=-2) + q = torch.cat( + [q[:, :, : self.num_prefix_tokens], apply_rot_embed_cat(q[:, :, self.num_prefix_tokens :], rope)], dim=2 + ).type_as(q) + k = torch.cat( + [k[:, :, : self.num_prefix_tokens], apply_rot_embed_cat(k[:, :, self.num_prefix_tokens :], rope)], dim=2 + ).type_as(q) + + qkv = torch.stack([q, k, v], dim=-2) + qkv = rearrange(qkv, "b h s three d -> b s three h d") + + query, key, value = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2] + + query = query.permute(0, 2, 1, 3) + key = key.permute(0, 2, 1, 3) + value = value.permute(0, 2, 1, 3) + if scaled_dot_product_attention is not None: + attn_output = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=self.drop.p, is_causal=False + ) + else: + attention_scores = torch.matmul(query, key.transpose(-1, -2)) / self.norm_factor + if attention_mask is not None: + attention_scores = attention_scores + attention_mask + + attentions_probs = F.softmax(attention_scores, dim=-1) + attentions_probs = self.drop(attentions_probs) + + attn_output = torch.matmul(attentions_probs, value) + + attn_output = rearrange(attn_output.permute(0, 2, 1, 3), "... h d -> ... (h d)") + + attn_output = self.out_proj(attn_output) + + return attn_output + + +class NomicBertBlock(NomicBertPreTrainedModel): + def __init__( + self, + config, + ): + super().__init__(config=config) + self.prenorm = config.prenorm + self.fused_dropout_add_ln = config.fused_dropout_add_ln + + self.attn = NomicBertAttention(config) + activation = ( + F.sigmoid + if config.activation_function == "glu" + else (F.silu if config.activation_function == "swiglu" else F.gelu) + ) + if config.activation_function in ["glu", "swiglu", "geglu"]: + self.mlp = NomciBertGatedMLP( + config.n_embd, + hidden_features=config.n_inner, + bias1=config.mlp_fc1_bias, + bias2=config.mlp_fc2_bias, + activation=activation, + fused_bias_fc=config.fused_bias_fc, + norm_layer=getattr(config, "norm_mlp", False), + ) + else: + self.mlp = NomicBertMLP( + config.n_embd, + hidden_features=config.n_inner, + bias1=config.mlp_fc1_bias, + bias2=config.mlp_fc2_bias, + activation=activation, + fused_bias_fc=config.fused_bias_fc, + ) + + self.dropout1 = nn.Dropout(config.resid_pdrop) + self.norm1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) + self.norm2 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) + self.dropout2 = nn.Dropout(config.resid_pdrop) + + def forward( + self, + hidden_states: torch.Tensor, + hidden_states2: torch.Tensor, + residual: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + is_padded_inputs: Optional[bool] = True, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cu_seqlens: Optional[torch.Tensor] = None, + max_seq_len: Optional[int] = None, + rope: Optional[torch.Tensor] = None, + ): + r"""Pass the input through the encoder layer. + + Args: + hidden_states: the sequence to the encoder layer (required). + residual: if postnorm, residual=None, If prenorm, hidden_states = Attn/MLP(LN(residual)) + mixer_subset: for cross-attention only. If not None, will take a subset of x + before applying the query projection. Useful for e.g., ViT where we only care + about the CLS token in the last layer. + """ + if self.prenorm: + dropped = self.dropout1(hidden_states) + residual = (dropped + residual) if residual is not None else dropped + hidden_states = self.norm1(residual.to(dtype=self.norm1.weight.dtype)) + hidden_states = self.attn( + hidden_states, + attention_mask=attention_mask, + is_padded_inputs=is_padded_inputs, + cu_seqlens=cu_seqlens, + max_seq_len=max_seq_len, + rope=rope, + ) + + dropped = self.dropout2(hidden_states) + residual = (dropped + residual) if residual is not None else dropped + hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype)) + hidden_states = self.mlp(hidden_states) + + return hidden_states, None, residual + else: + assert residual is None + attn_outputs = self.attn( + hidden_states, + attention_mask=attention_mask, + is_padded_inputs=is_padded_inputs, + cu_seqlens=cu_seqlens, + max_seq_len=max_seq_len, + rope=rope, + ) + hidden_states = self.norm1((self.dropout1(attn_outputs) + hidden_states).to(dtype=self.norm1.weight.dtype)) + mlp_out = self.mlp(hidden_states) + + hidden_states = self.norm2((self.dropout2(mlp_out) + hidden_states).to(dtype=self.norm2.weight.dtype)) + return hidden_states, None, None + + +class NomicBertEncoder(nn.Module): + def __init__(self, config: GPT2Config): + super().__init__() + self.layers = nn.ModuleList([NomicBertBlock(config) for _ in range(config.n_layer)]) + self.gradient_checkpointing = False + self.config = config + + def forward( + self, + hidden_states: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + is_padded_inputs: Optional[bool] = True, + rope: Optional[torch.Tensor] = None, + ): + """If subset_mask is not None, we only want output for the subset of the sequence. + This means that we only compute the last layer output for these tokens. + subset_mask: (batch, seqlen), dtype=torch.bool + """ + hidden_states2 = None + residual = None + + for _, layer in enumerate(self.layers): + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs) + + return custom_forward + + hidden_states, hidden_states2, residual = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer), + hidden_states, + hidden_states2, + residual, + attention_mask, + position_ids, + past_key_values, + is_padded_inputs, + output_attentions, + use_cache, + None, + None, + rope, + # if you freeze ANY layers, you need `use_reentrant=False` + # https://github.com/huggingface/transformers/issues/21381 + # https://discuss.pytorch.org/t/checkpoint-with-no-grad-requiring-inputs-problem/19117/7 + use_reentrant=False, + ) + + else: + hidden_states, hidden_states2, residual = layer( + hidden_states, + hidden_states2, + residual, + attention_mask, + position_ids, + None, + is_padded_inputs, + output_attentions, + use_cache, + rope=rope, + ) + return hidden_states + + +class NomicBertPooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.n_embd, config.n_embd) + self.activation = nn.Tanh() + + def forward(self, hidden_states, pool=True): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] if pool else hidden_states + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class NomicBertPredictionHeadTransform(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.n_embd, config.n_embd, bias=config.mlp_fc1_bias) + approximate = "tanh" if config.activation_function in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"] else "none" + if config.activation_function == "swiglu": + self.transform_act_fn = F.silu + else: + self.transform_act_fn = nn.GELU(approximate=approximate) + + self.layer_norm = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.layer_norm(hidden_states) + + return hidden_states + + +class NomicBertLMPredictionHead(nn.Module): + def __init__(self, config): + super().__init__() + + self.transform = NomicBertPredictionHeadTransform(config) + + self.decoder = nn.Linear(config.n_embd, config.vocab_size, bias=config.mlp_fc1_bias) + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + return hidden_states + + +class NomicBertPreTrainingHeads(nn.Module): + def __init__(self, config): + super().__init__() + self.predictions = NomicBertLMPredictionHead(config) + + def forward(self, sequence_output): + prediction_scores = self.predictions(sequence_output) + return prediction_scores + + +class NomicBertModel(NomicBertPreTrainedModel): + def __init__(self, config: GPT2Config, add_pooling_layer=True): + super().__init__(config) + self.pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1) + if config.vocab_size % self.pad_vocab_size_multiple != 0: + config.vocab_size += self.pad_vocab_size_multiple - (config.vocab_size % self.pad_vocab_size_multiple) + + assert config.activation_function in [ + "gelu", + "gelu_new", + "gelu_fast", + "gelu_pytorch_tanh", + "swiglu", + "geglu", + "glu", + ] + + self.embeddings = NomicBertEmbeddings(config) + self.emb_drop = nn.Dropout(config.resid_pdrop) + self.emb_ln = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) + self.encoder = NomicBertEncoder(config) + self.pooler = NomicBertPooler(config) if add_pooling_layer else None + + self.apply(partial(_init_weights, initializer_range=config.initializer_range)) + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + token_type_ids=None, + return_dict=None, + matryoshka_dim=None, + inputs_embeds=None, + ): + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + hidden_states = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + ) + hidden_states = self.emb_ln(hidden_states) + hidden_states = self.emb_drop(hidden_states) + + attention_mask = self.get_extended_attention_mask(attention_mask, hidden_states.shape[:-1]) + sequence_output = self.encoder(hidden_states, attention_mask=attention_mask, return_dict=return_dict) + + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + if matryoshka_dim: + sequence_output = sequence_output[:, :matryoshka_dim] + + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + ) + + +class NomicBertForPreTraining(NomicBertPreTrainedModel): + _tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"] + + def __init__(self, config: GPT2Config): + super().__init__(config) + + self.bert = NomicBertModel(config, add_pooling_layer=getattr(config, "add_pooling_layer", False)) + self.cls = NomicBertPreTrainingHeads(config) + self.mlm_loss = nn.CrossEntropyLoss() + + # Initialize weights and apply final processing + self.apply(partial(_init_weights, initializer_range=config.initializer_range)) + self.tie_weights() + + def tie_weights(self): + self.cls.predictions.decoder.weight = self.bert.embeddings.word_embeddings.weight + + def forward( + self, + input_ids, + position_ids=None, + token_type_ids=None, + attention_mask=None, + labels=None, + ): + """ + If labels are provided, they must be -100 for masked out tokens (as specified in the attention + mask). + Outputs: + if `labels` and `next_sentence_label` are not `None`: + Outputs the total_loss which is the sum of the masked language modeling loss and the next + sentence classification loss. + if `labels` or `next_sentence_label` is `None`: + Outputs a tuple comprising + - the masked language modeling logits of shape [batch_size, sequence_length, vocab_size], and + - the next sentence classification logits of shape [batch_size, 2]. + + """ + outputs = self.bert( + input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + attention_mask=attention_mask.bool() if attention_mask is not None else None, + ) + sequence_output, _ = outputs.last_hidden_state, outputs.pooler_output + + prediction_scores = self.cls(sequence_output) + + total_loss = None + if labels is not None: + masked_lm_loss = self.mlm_loss( + rearrange(prediction_scores, "... v -> (...) v"), + rearrange(labels, "... -> (...)"), + ) + total_loss = masked_lm_loss.float() + + return MaskedLMOutput( + loss=total_loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=None, + ) + + +class NomicBertForSequenceClassification(NomicBertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.config = config + + self.bert = NomicBertModel(config) + classifier_dropout = getattr(config, "classifier_dropout", config.embd_pdrop) + self.dropout = nn.Dropout(classifier_dropout) + self.classifier = nn.Linear(config.n_embd, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + outputs = self.bert( + input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + attention_mask=attention_mask.bool() if attention_mask is not None else None, + ) + + pooled_output = outputs[1] + + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = nn.MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = nn.CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = nn.BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + +class NomicBertForMultipleChoice(NomicBertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.bert = NomicBertModel(config, add_pooling_layer=True) + classifier_dropout = ( + getattr(config, "classifier_dropout", config.resid_pdrop) + ) + self.dropout = nn.Dropout(classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, 1) + + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + unpad_inputs: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], MultipleChoiceModelOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., + num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See + `input_ids` above) + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1] + + input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None + attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None + token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None + position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None + inputs_embeds = ( + inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) + if inputs_embeds is not None + else None + ) + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + ) + + pooled_output = outputs[1] + + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + reshaped_logits = logits.view(-1, num_choices) + + loss = None + if labels is not None: + loss_fct = nn.CrossEntropyLoss() + loss = loss_fct(reshaped_logits, labels) + + if not return_dict: + output = (reshaped_logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return MultipleChoiceModelOutput( + loss=loss, + logits=reshaped_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + +class NomicBertForTokenClassification(NomicBertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.bert = NomicBertModel(config, add_pooling_layer=False) + classifier_dropout = ( + getattr(config, "classifier_dropout", config.resid_pdrop) + ) + self.dropout = nn.Dropout(classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + ) + + sequence_output = outputs[0] + + sequence_output = self.dropout(sequence_output) + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + loss_fct = nn.CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + +class NomicBertForQuestionAnswering(NomicBertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.bert = NomicBertModel(config, add_pooling_layer=False) + self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + start_positions: Optional[torch.Tensor] = None, + end_positions: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + +def hf_vit_config_to_vit_config(vit_config: ViTConfig) -> GPT2Config: + return GPT2Config( + n_embd=vit_config.hidden_size, + n_layer=vit_config.num_hidden_layers, + n_head=vit_config.num_attention_heads, + n_inner=vit_config.intermediate_size, + activation_function=vit_config.hidden_act, + vocab_size=0, # no vocab since using patches + n_positions=0, # No absolute position embedding + resid_pdrop=0.0, # No dropout + embd_pdrop=getattr(vit_config, "dropout", 0.0), + attn_pdrop=vit_config.attention_probs_dropout_prob, + layer_norm_epsilon=vit_config.layer_norm_eps, + initializer_range=vit_config.initializer_range, + bos_token_id=None, + eos_token_id=None, + # These are new arguments not in the original GPT2Config + drop_path_rate=0.0, + # Why is there double layer norm?? + prepre_layernom=False, + layer_scale=False, + layer_scale_init=None, + img_size=vit_config.image_size, + patch_size=vit_config.patch_size, + num_channels=vit_config.num_channels, + prenorm=True, + parallel_block=False, + parallel_block_tied_norm=False, + rotary_emb_fraction=0, + tie_word_embeddings=False, + fused_dropout_add_ln=True, + fused_bias_fc=True, + patch_embed_bias=True, + use_flash_attn=True, + qkv_proj_bias=True, + mlp_fc1_bias=getattr(vit_config, "mlp_fc1_bias", True), + mlp_fc2_bias=getattr(vit_config, "mlp_fc2_bias", True), + use_rms_norm=False, + causal=False, + hidden_features_scaling_factor=1.0, + mask_token=False, + learned_pos_embedding=False, + patch_dropout=0, + sinusoidal_pos_embedding=vit_config.model_type == "vit_mae", + ) + + +class NomicAttentionPooling(nn.Module): + def __init__(self, config): + super().__init__() + self.embed_dim = config.n_embd + self.use_flash_attn = config.use_flash_attn + self.fused_bias_fc = config.fused_bias_fc + + self.num_heads = config.n_head + self.num_heads_kv = config.num_heads_kv if getattr(config, "num_heads_kv", None) is not None else self.num_heads + assert self.embed_dim % self.num_heads == 0, "embed_dim must be divisible by num_heads" + self.head_dim = self.embed_dim // self.num_heads + # we don't really support mqa / gqa for now + kv_dim = 2 * self.head_dim * self.num_heads_kv + + self.register_buffer( + "norm_factor", + torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32)).to(torch.get_default_dtype()), + persistent=False, + ) + + self.Wq = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_proj_bias) + self.Wkv = nn.Linear(self.embed_dim, kv_dim, bias=config.qkv_proj_bias) + + self.latent = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) + + self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_proj_bias) + self.causal = config.causal + self.drop = nn.Dropout(config.attn_pdrop) + + def init_weights(self): + trunc_normal_tf_(self.latent, std=self.embed_dim**-0.5) + + def forward( + self, + kv, + attention_mask=None, + cu_seqlens_k=None, + max_seqlen_k=None, + is_padded_inputs: Optional[bool] = True, + output_attentions: bool = False, + ): + """Implements the multihead softmax attention. + Arguments + --------- + q: The tensor containing the query. (B, Sq, H, D) + kv: The tensor containing the key and value. (B, Sk, 2, H_k, D) + causal: if passed, will override self.causal + cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths + of the sequences in the batch, used to index into q. + max_seqlen: int. Maximum sequence length in the batch of q. + cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths + of the sequences in the batch, used to index into kv. + max_seqlen_k: int. Maximum sequence length in the batch of k and v. + """ + q_latent = self.latent.expand(kv.size(0), -1, -1) + q = self.Wq(q_latent) + bsz, q_len, h_size = q.shape + kv = self.Wkv(kv) + query = rearrange(q, "... (h d) -> ... h d", d=self.head_dim) + kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim) + + key, value = kv[:, :, 0], kv[:, :, 1] + + query = query.permute(0, 2, 1, 3) + key = key.permute(0, 2, 1, 3) + value = value.permute(0, 2, 1, 3) + + attention_scores = torch.matmul(query, key.transpose(-1, -2)) / self.norm_factor + if attention_mask is not None: + attention_scores = attention_scores + attention_mask + + attentions_probs = F.softmax(attention_scores, dim=-1) + attentions_probs = self.drop(attentions_probs) + + attn_output = torch.matmul(attentions_probs, value) + attn_output = rearrange(attn_output.permute(0, 2, 1, 3), "... h d -> ... (h d)") + + attn_output = self.out_proj(attn_output) + + return attn_output + + +class NomicMultiHeadAttentionPooling(nn.Module): + def __init__( + self, + config, + ): + super().__init__() + self.prenorm = config.prenorm + self.fused_dropout_add_ln = config.fused_dropout_add_ln + + self.attn = NomicAttentionPooling(config) + activation = ( + F.sigmoid + if config.activation_function == "glu" + else (F.silu if config.activation_function == "swiglu" else F.gelu) + ) + if config.activation_function in ["glu", "swiglu", "geglu"]: + self.mlp = NomciBertGatedMLP( + config.n_embd, + hidden_features=config.n_inner, + bias1=config.mlp_fc1_bias, + bias2=config.mlp_fc2_bias, + activation=activation, + fused_bias_fc=config.fused_bias_fc, + ) + else: + self.mlp = NomicBertMLP( + config.n_embd, + hidden_features=config.n_inner, + bias1=config.mlp_fc1_bias, + bias2=config.mlp_fc2_bias, + activation=activation, + fused_bias_fc=config.fused_bias_fc, + ) + + self.dropout1 = nn.Dropout(config.resid_pdrop) + self.norm1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) + self.dropout2 = nn.Dropout(config.resid_pdrop) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ): + r"""Pass the input through the encoder layer. + + Args: + hidden_states: the sequence to the encoder layer (required). + residual: if postnorm, residual=None, If prenorm, hidden_states = Attn/MLP(LN(residual)) + mixer_subset: for cross-attention only. If not None, will take a subset of x + before applying the query projection. Useful for e.g., ViT where we only care + about the CLS token in the last layer. + """ + + attn_outputs = self.attn( + hidden_states, + attention_mask=attention_mask, + ) + + normed = self.norm1(attn_outputs) + hidden_states = hidden_states + self.mlp(normed) + + return hidden_states + + +class NomicVisionPreTrainedModel(PreTrainedModel): + """An abstract class to handle weights initialization and + a simple interface for dowloading and loading pretrained models. + """ + + config_class = NomicBertConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["Block"] + _skip_keys_device_placement = "past_key_values" + + def __init__(self, config, *inputs, **kwargs): + super().__init__(config) + if not isinstance(config, GPT2Config): + raise ValueError( + "Parameter config in `{}(config)` should be an instance of class `GPT2Config`. " + "To create a model from a Google pretrained model use " + "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format( + self.__class__.__name__, self.__class__.__name__ + ) + ) + self.config = config + + +class NomicVisionModel(NomicVisionPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.embeddings = NomicVisionPatchEmbeddings(config) + self.layers = nn.ModuleList([NomicBertBlock(config) for _ in range(config.n_layer)]) + + self.selector = NomicMultiHeadAttentionPooling(config) + + self.global_pool = getattr(config, "global_pool", None) + self.num_prefix_tokens = (1 if not getattr(config, "no_cls_token", False) else 0) + getattr( + config, "register_tokens", 0 + ) + + self.apply(partial(_init_weights, initializer_range=config.initializer_range)) + + def forward( + self, + pixel_values, + attention_mask=None, + position_ids=None, + token_type_ids=None, + return_dict=None, + matryoshka_dim=None, + ): + embeddings, rope = self.embeddings(pixel_values) + + original_dtype = embeddings.dtype + + hidden_states = embeddings + # unused but easier to pass to gradient checkpointing as words + residual = None + for layer in self.layers: + # need to pass none for backwards compatability + hidden_states, _, residual = layer( + hidden_states, None, residual=residual, is_padded_inputs=False, rope=rope + ) + + hidden_states = hidden_states + residual + if self.global_pool == "avg": + hidden_states = hidden_states[:, self.num_prefix_tokens :].mean(dim=1) + + pooled_output = self.selector(hidden_states) + + return BaseModelOutputWithPast( + last_hidden_state=pooled_output, + hidden_states=hidden_states, + )