# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from funasr_detach.models.data2vec.multihead_attention import MultiheadAttention


class Fp32LayerNorm(nn.LayerNorm):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, input):
        output = F.layer_norm(
            input.float(),
            self.normalized_shape,
            self.weight.float() if self.weight is not None else None,
            self.bias.float() if self.bias is not None else None,
            self.eps,
        )
        return output.type_as(input)
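
# Note (added for clarity): Fp32LayerNorm runs LayerNorm in float32 and casts the
# result back to the input dtype, which keeps the normalization statistics stable
# when the surrounding model runs in fp16/bf16. Illustrative sketch, not part of
# the original file (the feature size 256 is made up):
#
#   ln = Fp32LayerNorm(256)
#   y = ln(torch.randn(4, 100, 256, dtype=torch.float16))  # y.dtype == torch.float16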


class Fp32GroupNorm(nn.GroupNorm):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, input):
        output = F.group_norm(
            input.float(),
            self.num_groups,
            self.weight.float() if self.weight is not None else None,
            self.bias.float() if self.bias is not None else None,
            self.eps,
        )
        return output.type_as(input)


class TransposeLast(nn.Module):
    def __init__(self, deconstruct_idx=None):
        super().__init__()
        self.deconstruct_idx = deconstruct_idx

    def forward(self, x):
        if self.deconstruct_idx is not None:
            x = x[self.deconstruct_idx]
        return x.transpose(-2, -1)
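
# Note (added for clarity): TransposeLast swaps the last two dimensions, e.g. to move
# between (batch, time, channels) and (batch, channels, time) around 1-D convolutions.
# Illustrative sketch, not part of the original file:
#
#   t = TransposeLast()
#   t(torch.randn(4, 100, 512)).shape  # torch.Size([4, 512, 100])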


class SamePad(nn.Module):
    def __init__(self, kernel_size, causal=False):
        super().__init__()
        if causal:
            self.remove = kernel_size - 1
        else:
            self.remove = 1 if kernel_size % 2 == 0 else 0

    def forward(self, x):
        if self.remove > 0:
            x = x[:, :, : -self.remove]
        return x
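
# Note (added for clarity): SamePad trims the extra trailing frames produced by
# symmetric padding so a convolution's output length matches its input ("same"
# padding); in causal mode it removes kernel_size - 1 trailing frames. Illustrative
# sketch, not part of the original file (channel and length sizes are made up):
#
#   conv = nn.Sequential(nn.Conv1d(8, 8, kernel_size=4, padding=2), SamePad(4))
#   conv(torch.randn(1, 8, 100)).shape  # torch.Size([1, 8, 100])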


def pad_to_multiple(x, multiple, dim=-1, value=0):
    # Inspired from https://github.com/lucidrains/local-attention/blob/master/local_attention/local_attention.py#L41
    if x is None:
        return None, 0
    tsz = x.size(dim)
    m = tsz / multiple
    remainder = math.ceil(m) * multiple - tsz
    if m.is_integer():
        return x, 0
    pad_offset = (0,) * (-1 - dim) * 2
    return F.pad(x, (*pad_offset, 0, remainder), value=value), remainder
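
# Note (added for clarity): pad_to_multiple right-pads `dim` so its size becomes a
# multiple of `multiple`, returning the (possibly unchanged) tensor and the number
# of padded positions. Illustrative sketch, not part of the original file:
#
#   x = torch.randn(2, 10, 8)
#   padded, pad = pad_to_multiple(x, 4, dim=-2)  # padded.shape == (2, 12, 8), pad == 2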


def gelu_accurate(x):
    if not hasattr(gelu_accurate, "_a"):
        gelu_accurate._a = math.sqrt(2 / math.pi)
    return (
        0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3))))
    )
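
# Note (added for clarity): gelu_accurate is the tanh approximation of GELU,
# 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))); the constant sqrt(2/pi)
# is cached on the function object after the first call.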


def gelu(x: torch.Tensor) -> torch.Tensor:
    return torch.nn.functional.gelu(x.float()).type_as(x)


def get_available_activation_fns():
    return [
        "relu",
        "gelu",
        "gelu_fast",  # deprecated
        "gelu_accurate",
        "tanh",
        "linear",
    ]


def get_activation_fn(activation: str):
    """Returns the activation function corresponding to `activation`"""
    if activation == "relu":
        return F.relu
    elif activation == "gelu":
        return gelu
    elif activation == "gelu_fast":
        # deprecated name, kept as an alias for gelu_accurate
        return gelu_accurate
    elif activation == "gelu_accurate":
        return gelu_accurate
    elif activation == "tanh":
        return torch.tanh
    elif activation == "linear":
        return lambda x: x
    elif activation == "swish":
        # instantiate the module so the returned object can be called on tensors
        return torch.nn.SiLU()
    else:
        raise RuntimeError("--activation-fn {} not supported".format(activation))
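
# Illustrative usage, not part of the original file:
#
#   act = get_activation_fn("gelu_accurate")
#   y = act(torch.randn(4, 8))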


def init_bert_params(module):
    """
    Initialize the weights specific to the BERT Model.
    This overrides the default initializations depending on the specified arguments.
    1. If normal_init_linear_weights is set then weights of linear
       layer will be initialized using the normal distribution and
       bias will be set to the specified value.
    2. If normal_init_embed_weights is set then weights of embedding
       layer will be initialized using the normal distribution.
    3. If normal_init_proj_weights is set then weights of
       in_project_weight for MultiHeadAttention will be initialized using
       the normal distribution (to be validated).
    """

    def normal_(data):
        # with FSDP, module params will be on CUDA, so we cast them back to CPU
        # so that the RNG is consistent with and without FSDP
        data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device))

    if isinstance(module, nn.Linear):
        normal_(module.weight.data)
        if module.bias is not None:
            module.bias.data.zero_()
    if isinstance(module, nn.Embedding):
        normal_(module.weight.data)
        if module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()
    if isinstance(module, MultiheadAttention):
        normal_(module.q_proj.weight.data)
        normal_(module.k_proj.weight.data)
        normal_(module.v_proj.weight.data)
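

# Illustrative usage, not part of the original file: nn.Module.apply visits every
# submodule, so a whole model (here a hypothetical `model`) can be initialized in
# one call.
#
#   model.apply(init_bert_params)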