import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import XLMRobertaModel
from typing import Optional
import logging
import os
import json

logger = logging.getLogger(__name__)

SUPPORTED_LANGUAGES = {
    'en': 0, 'ru': 1, 'tr': 2, 'es': 3,
    'fr': 4, 'it': 5, 'pt': 6
}


def validate_lang_ids(lang_ids):
    """Coerce language IDs to a long tensor and clamp them to the supported range."""
    if not isinstance(lang_ids, torch.Tensor):
        lang_ids = torch.tensor(lang_ids, dtype=torch.long)
    return torch.clamp(lang_ids, min=0, max=len(SUPPORTED_LANGUAGES) - 1)
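
# Example (illustrative values): IDs outside the supported range are clamped,
# e.g. validate_lang_ids([0, 3, 9]) -> tensor([0, 3, 6]).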


class LanguageAwareClassifier(nn.Module):
    def __init__(self, hidden_size=1024, num_labels=6):
        super().__init__()
        self.lang_embed = nn.Embedding(len(SUPPORTED_LANGUAGES), 64)

        self.classifier = nn.Sequential(
            nn.Linear(hidden_size + 64, 512),
            nn.LayerNorm(512),
            nn.GELU(),
            nn.Linear(512, num_labels)
        )

        # Learned per-language, per-label scaling factors, applied as sigmoid
        # gates on the logits and initialized close to 1.
        self.lang_thresholds = nn.Parameter(
            torch.ones(len(SUPPORTED_LANGUAGES), num_labels)
        )
        nn.init.normal_(self.lang_thresholds, mean=1.0, std=0.01)

        self._init_weights()

    def _init_weights(self):
        """Initialize weights with Xavier uniform"""
        for module in self.classifier:
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.LayerNorm):
                nn.init.constant_(module.bias, 0)
                nn.init.constant_(module.weight, 1.0)

    def forward(self, x, lang_ids):
        if not isinstance(lang_ids, torch.Tensor):
            lang_ids = torch.tensor(lang_ids, dtype=torch.long, device=x.device)
        elif lang_ids.dtype != torch.long:
            lang_ids = lang_ids.long()

        lang_emb = self.lang_embed(lang_ids)
        combined = torch.cat([x, lang_emb], dim=-1)
        logits = self.classifier(combined)

        # Gate the logits with the per-language scaling factors.
        thresholds = self.lang_thresholds[lang_ids]
        logits = logits * torch.sigmoid(thresholds)

        return logits
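
# Usage sketch (shapes are illustrative): the head consumes pooled sentence
# features of size hidden_size together with integer language IDs and returns
# per-label logits, e.g.:
#   head = LanguageAwareClassifier(hidden_size=1024, num_labels=6)
#   logits = head(torch.randn(8, 1024), torch.zeros(8, dtype=torch.long))  # (8, 6)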


class WeightedBCEWithLogitsLoss(nn.Module):
    """Focal-style BCE-with-logits loss with optional per-example weights."""

    def __init__(self, gamma=2.0, reduction='mean'):
        super().__init__()
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, logits, targets, weights=None):
        bce_loss = F.binary_cross_entropy_with_logits(
            logits, targets, reduction='none'
        )
        # p_t = exp(-BCE) is the probability assigned to the true label;
        # (1 - p_t) ** gamma down-weights well-classified examples.
        pt = torch.exp(-bce_loss)
        focal_loss = (1 - pt) ** self.gamma * bce_loss
        if weights is not None:
            focal_loss = focal_loss * weights
        if self.reduction == 'sum':
            return focal_loss.sum()
        if self.reduction == 'none':
            return focal_loss
        return focal_loss.mean()
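
# Usage sketch (illustrative tensors): multi-label targets in {0, 1} with the
# same shape as the logits.
#   criterion = WeightedBCEWithLogitsLoss(gamma=2.0)
#   loss = criterion(torch.randn(8, 6), torch.randint(0, 2, (8, 6)).float())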


class LanguageAwareTransformer(nn.Module):
    def __init__(
        self,
        num_labels: int = 6,
        hidden_size: int = 1024,
        num_attention_heads: int = 16,
        model_name: str = "xlm-roberta-large",
        dropout: float = 0.0
    ):
        super().__init__()

        if not SUPPORTED_LANGUAGES:
            raise ValueError("No supported languages defined")
        logger.info(
            f"Initializing model with {len(SUPPORTED_LANGUAGES)} supported languages: "
            f"{list(SUPPORTED_LANGUAGES.keys())}"
        )

        self.base_model = XLMRobertaModel.from_pretrained(model_name)
        self.config = self.base_model.config

        self.original_hidden_size = self.config.hidden_size
        self.needs_projection = hidden_size != self.original_hidden_size
        if self.needs_projection:
            self.dim_projection = nn.Sequential(
                nn.Linear(self.original_hidden_size, hidden_size),
                nn.LayerNorm(hidden_size),
                nn.GELU()
            )

        self.working_hidden_size = hidden_size if self.needs_projection else self.original_hidden_size

        num_languages = len(SUPPORTED_LANGUAGES)
        self.lang_embed = nn.Embedding(num_languages, 64)

        self.register_buffer('valid_lang_ids', torch.arange(num_languages))

        # Maps the 64-d language embedding to one bias vector per attention head
        # (num_attention_heads * head_dim == hidden_size).
        self.lang_proj = nn.Sequential(
            nn.Linear(64, hidden_size),
            nn.LayerNorm(hidden_size),
            nn.Tanh()
        )

        self.num_attention_heads = num_attention_heads
        self.head_dim = hidden_size // num_attention_heads
        self.scale = self.head_dim ** -0.5

        self.q_proj = nn.Linear(hidden_size, hidden_size)
        self.k_proj = nn.Linear(hidden_size, hidden_size)
        self.v_proj = nn.Linear(hidden_size, hidden_size)
        self.dropout = nn.Dropout(dropout)

        self.post_attention = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.LayerNorm(hidden_size),
            nn.GELU()
        )

        self.classifier = nn.Sequential(
            nn.Linear(hidden_size, 512),
            nn.LayerNorm(512),
            nn.GELU(),
            nn.Linear(512, num_labels)
        )

        self._init_weights()
        self.gradient_checkpointing = False

    def _init_weights(self):
        """Initialize weights with careful scaling"""
        for module in [self.lang_proj, self.q_proj, self.k_proj, self.v_proj,
                       self.post_attention, self.classifier]:
            if isinstance(module, nn.Sequential):
                for layer in module:
                    if isinstance(layer, nn.Linear):
                        nn.init.xavier_uniform_(layer.weight)
                        if layer.bias is not None:
                            nn.init.zeros_(layer.bias)
                    elif isinstance(layer, nn.LayerNorm):
                        nn.init.ones_(layer.weight)
                        nn.init.zeros_(layer.bias)
            elif isinstance(module, nn.Linear):
                # Standalone linears in the list above are the q/k/v projections.
                nn.init.normal_(module.weight, std=0.02)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def gradient_checkpointing_enable(self):
        self.gradient_checkpointing = True
        self.base_model.gradient_checkpointing_enable()

    def gradient_checkpointing_disable(self):
        self.gradient_checkpointing = False
        self.base_model.gradient_checkpointing_disable()

    def validate_lang_ids(self, lang_ids: torch.Tensor) -> torch.Tensor:
        """
        Validate and normalize language IDs

        Args:
            lang_ids: Tensor of language IDs

        Returns:
            Validated and normalized language ID tensor

        Raises:
            ValueError if too many invalid IDs detected
        """
        if not isinstance(lang_ids, torch.Tensor):
            lang_ids = torch.tensor(lang_ids, dtype=torch.long, device=self.valid_lang_ids.device)
        elif lang_ids.dtype != torch.long:
            lang_ids = lang_ids.long()

        invalid_mask = ~torch.isin(lang_ids, self.valid_lang_ids)
        num_invalid = invalid_mask.sum().item()

        if num_invalid > 0:
            invalid_ratio = num_invalid / lang_ids.numel()
            if invalid_ratio > 0.1:
                raise ValueError(
                    f"Too many invalid language IDs detected ({num_invalid} out of {lang_ids.numel()}). "
                    f"Valid range is 0-{len(SUPPORTED_LANGUAGES)-1}"
                )

            logger.warning(
                f"Found {num_invalid} invalid language IDs. "
                f"Valid range is 0-{len(SUPPORTED_LANGUAGES)-1}. "
                "Invalid IDs will be clamped to valid range."
            )
            lang_ids = torch.clamp(lang_ids, min=0, max=len(SUPPORTED_LANGUAGES)-1)

        return lang_ids

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
        lang_ids: Optional[torch.Tensor] = None,
        mode: str = 'train'
    ) -> dict:
        device = input_ids.device
        batch_size = input_ids.size(0)

        if lang_ids is None:
            lang_ids = torch.zeros(batch_size, dtype=torch.long, device=device)

        try:
            lang_ids = self.validate_lang_ids(lang_ids)
        except ValueError as e:
            logger.error(f"Language ID validation failed: {str(e)}")
            logger.error("Falling back to default language (0)")
            lang_ids = torch.zeros_like(lang_ids)

        hidden_states = self.base_model(
            input_ids=input_ids,
            attention_mask=attention_mask
        ).last_hidden_state

        if hidden_states.isnan().any():
            raise ValueError("NaN detected in hidden states")
        if hidden_states.isinf().any():
            raise ValueError("Inf detected in hidden states")

        if self.needs_projection:
            hidden_states = self.dim_projection(hidden_states)

        lang_emb = self.lang_embed(lang_ids)
        lang_bias = self.lang_proj(lang_emb)

        batch_size, seq_len, hidden_size = hidden_states.shape
        num_heads = self.num_attention_heads
        head_dim = self.head_dim

        # Multi-head projections: (batch, seq_len, num_heads, head_dim)
        q = self.q_proj(hidden_states).view(batch_size, seq_len, num_heads, head_dim)
        k = self.k_proj(hidden_states).view(batch_size, seq_len, num_heads, head_dim)
        v = self.v_proj(hidden_states).view(batch_size, seq_len, num_heads, head_dim)

        # -> (batch, num_heads, seq_len, head_dim)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        # Language bias: one vector per head, added to the attention scores as a
        # per-query offset q @ b_lang of shape (batch, num_heads, seq_len, 1).
        attn_bias = lang_bias.view(batch_size, num_heads, head_dim, 1)
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        attn_scores = attn_scores + torch.matmul(q, attn_bias)

        if attention_mask is not None:
            attn_scores = attn_scores.masked_fill(
                ~attention_mask.bool().unsqueeze(1).unsqueeze(2),
                float('-inf')
            )

        attn_weights = F.softmax(attn_scores, dim=-1)
        attn_weights = self.dropout(attn_weights)
        attention_output = torch.matmul(attn_weights, v)

        # Merge heads back to (batch, seq_len, hidden_size)
        attention_output = attention_output.transpose(1, 2).contiguous().view(
            batch_size, seq_len, hidden_size
        )
        output = self.post_attention(attention_output)

        # Classify from the first token (<s>, XLM-R's [CLS] equivalent)
        logits = self.classifier(output[:, 0])

        # Per-language logit offsets, applied only at inference time.
        LANG_THRESHOLD_ADJUSTMENTS = {
            0: [0.00, 0.00, 0.00, 0.00, 0.00, 0.00],
            1: [-0.02, 0.00, 0.02, 0.00, -0.03, 0.00],
            2: [-0.02, 0.00, 0.02, 0.00, -0.03, 0.00],
            3: [-0.02, 0.00, 0.02, 0.00, -0.03, 0.00],
            4: [-0.02, 0.00, 0.02, 0.00, -0.03, 0.00],
            5: [-0.02, 0.00, 0.02, 0.00, -0.03, 0.00],
            6: [-0.02, 0.00, 0.02, 0.00, -0.03, 0.00],
        }

        if mode == 'inference':
            threshold_adj = torch.tensor(
                [LANG_THRESHOLD_ADJUSTMENTS[lang.item()] for lang in lang_ids],
                device=logits.device,
                dtype=logits.dtype
            )
            logits = logits + threshold_adj

        probabilities = torch.sigmoid(logits)

        result = {
            'logits': logits,
            'probabilities': probabilities
        }

        if labels is not None:
            loss_fct = WeightedBCEWithLogitsLoss()
            result['loss'] = loss_fct(logits, labels)

        return result
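
    # Training sketch (illustrative tensors): passing multi-hot `labels` adds a
    # 'loss' entry computed with WeightedBCEWithLogitsLoss, e.g.
    #   out = model(input_ids, attention_mask, labels=labels, lang_ids=lang_ids)
    #   out['loss'].backward()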

    def save_pretrained(self, save_path: str):
        os.makedirs(save_path, exist_ok=True)
        torch.save(self.state_dict(), os.path.join(save_path, 'pytorch_model.bin'))

        # Save constructor-level values (not the base model's config) so the
        # model can be re-instantiated with the same head dimensions.
        config_dict = {
            'num_labels': self.classifier[-1].out_features,
            'hidden_size': self.working_hidden_size,
            'num_attention_heads': self.num_attention_heads,
            'model_name': self.config.name_or_path,
            'dropout': self.dropout.p
        }

        with open(os.path.join(save_path, 'config.json'), 'w') as f:
            json.dump(config_dict, f, indent=2)
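

# Minimal smoke-test sketch (assumptions: network access to download
# "xlm-roberta-large" and its tokenizer; texts and language IDs are
# illustrative only). Runs a single forward pass in inference mode.
if __name__ == "__main__":
    from transformers import XLMRobertaTokenizerFast

    tokenizer = XLMRobertaTokenizerFast.from_pretrained("xlm-roberta-large")
    model = LanguageAwareTransformer(num_labels=6)
    model.eval()

    texts = ["This is a test comment.", "Ceci est un commentaire de test."]
    lang_ids = torch.tensor([SUPPORTED_LANGUAGES['en'], SUPPORTED_LANGUAGES['fr']])
    batch = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")

    with torch.no_grad():
        out = model(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            lang_ids=lang_ids,
            mode="inference",
        )
    print(out["probabilities"].shape)  # expected: torch.Size([2, 6])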