# Modified from transformers.models.xlm_roberta.modeling_xlm_roberta
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F

__all__ = ['XLMRoberta', 'xlm_roberta_large']
class SelfAttention(nn.Module):

    def __init__(self, dim, num_heads, dropout=0.1, eps=1e-5):
        assert dim % num_heads == 0
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.eps = eps

        # layers
        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.o = nn.Linear(dim, dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        """
        x:    [B, L, C].
        mask: additive float mask, broadcastable to [B, N, L, L].
        """
        b, s, c, n, d = *x.size(), self.num_heads, self.head_dim

        # compute query, key, value
        q = self.q(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
        k = self.k(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
        v = self.v(x).reshape(b, s, n, d).permute(0, 2, 1, 3)

        # compute attention (attention dropout is active only in training)
        p = self.dropout.p if self.training else 0.0
        x = F.scaled_dot_product_attention(q, k, v, mask, p)
        x = x.permute(0, 2, 1, 3).reshape(b, s, c)

        # output
        x = self.o(x)
        x = self.dropout(x)
        return x
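
# A minimal usage sketch (illustrative, not part of the module's API): the
# mask is an additive float tensor broadcastable against the [B, N, L, L]
# attention logits, with 0 at positions to attend to and a large negative
# value at positions to ignore.
#
#   attn = SelfAttention(dim=1024, num_heads=16)
#   x = torch.randn(2, 7, 1024)
#   mask = torch.zeros(2, 1, 1, 7)
#   mask[:, :, :, 5:] = torch.finfo(x.dtype).min  # mask out the last 2 keys
#   y = attn(x, mask)                             # [2, 7, 1024]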
class AttentionBlock(nn.Module):

    def __init__(self, dim, num_heads, post_norm, dropout=0.1, eps=1e-5):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.post_norm = post_norm
        self.eps = eps

        # layers
        self.attn = SelfAttention(dim, num_heads, dropout, eps)
        self.norm1 = nn.LayerNorm(dim, eps=eps)
        self.ffn = nn.Sequential(
            nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim),
            nn.Dropout(dropout))
        self.norm2 = nn.LayerNorm(dim, eps=eps)

    def forward(self, x, mask):
        if self.post_norm:
            x = self.norm1(x + self.attn(x, mask))
            x = self.norm2(x + self.ffn(x))
        else:
            x = x + self.attn(self.norm1(x), mask)
            x = x + self.ffn(self.norm2(x))
        return x
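
# Design note with a hedged sketch: post_norm=True gives the original
# BERT/XLM-R ordering (LayerNorm applied after each residual sum), which is
# what xlm_roberta_large below uses; post_norm=False gives a pre-LN block,
# and the model then relies on its final `norm` to normalize the output.
#
#   block = AttentionBlock(dim=1024, num_heads=16, post_norm=True)
#   y = block(torch.randn(2, 7, 1024), mask=None)  # mask=None disables masking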
class XLMRoberta(nn.Module):
    """
    XLMRobertaModel with no pooler and no LM head.
    """

    def __init__(self,
                 vocab_size=250002,
                 max_seq_len=514,
                 type_size=1,
                 pad_id=1,
                 dim=1024,
                 num_heads=16,
                 num_layers=24,
                 post_norm=True,
                 dropout=0.1,
                 eps=1e-5):
        super().__init__()
        self.vocab_size = vocab_size
        self.max_seq_len = max_seq_len
        self.type_size = type_size
        self.pad_id = pad_id
        self.dim = dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.post_norm = post_norm
        self.eps = eps

        # embeddings
        self.token_embedding = nn.Embedding(vocab_size, dim, padding_idx=pad_id)
        self.type_embedding = nn.Embedding(type_size, dim)
        self.pos_embedding = nn.Embedding(max_seq_len, dim, padding_idx=pad_id)
        self.dropout = nn.Dropout(dropout)

        # blocks
        self.blocks = nn.ModuleList([
            AttentionBlock(dim, num_heads, post_norm, dropout, eps)
            for _ in range(num_layers)
        ])

        # norm layer
        self.norm = nn.LayerNorm(dim, eps=eps)

    def forward(self, ids):
        """
        ids: [B, L] of torch.LongTensor, padded with self.pad_id.
        Returns hidden states of shape [B, L, C].
        """
        b, s = ids.shape
        mask = ids.ne(self.pad_id).long()

        # embeddings
        x = self.token_embedding(ids) + \
            self.type_embedding(torch.zeros_like(ids)) + \
            self.pos_embedding(self.pad_id + torch.cumsum(mask, dim=1) * mask)
        if self.post_norm:
            x = self.norm(x)
        x = self.dropout(x)

        # blocks: convert the 0/1 pad mask to an additive attention mask
        mask = torch.where(
            mask.view(b, 1, 1, s).gt(0), 0.0,
            torch.finfo(x.dtype).min)
        for block in self.blocks:
            x = block(x, mask)

        # output
        if not self.post_norm:
            x = self.norm(x)
        return x
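
# Worked example of the position ids computed in forward() (this mirrors
# Hugging Face's create_position_ids_from_input_ids): non-pad tokens are
# numbered from pad_id + 1 onward, while pad positions map to pad_id itself,
# whose embedding row is zeroed via padding_idx. With pad_id = 1:
#
#   ids  = [0, 204, 9, 2, 1, 1]    ->  mask = [1, 1, 1, 1, 0, 0]
#   pos  = 1 + cumsum(mask) * mask  =  [2, 3, 4, 5, 1, 1]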
def xlm_roberta_large(pretrained=False,
                      return_tokenizer=False,
                      device='cpu',
                      **kwargs):
    """
    XLMRobertaLarge adapted from Huggingface. Note that `pretrained` and
    `return_tokenizer` are accepted for API compatibility but unused here:
    this function loads no checkpoint and returns no tokenizer.
    """
    # params
    cfg = dict(
        vocab_size=250002,
        max_seq_len=514,
        type_size=1,
        pad_id=1,
        dim=1024,
        num_heads=16,
        num_layers=24,
        post_norm=True,
        dropout=0.1,
        eps=1e-5)
    cfg.update(**kwargs)

    # init a model on device
    with torch.device(device):
        model = XLMRoberta(**cfg)
    return model
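
if __name__ == '__main__':
    # Quick smoke test (illustrative only: weights are randomly initialized,
    # since this factory does not load a pretrained checkpoint).
    model = xlm_roberta_large().eval()
    ids = torch.randint(0, 250002, (2, 16))
    ids[:, 12:] = 1  # right-pad with pad_id
    with torch.no_grad():
        out = model(ids)
    print(out.shape)  # torch.Size([2, 16, 1024])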