"""PyTorch TextSyncMimi model - Text-synchronous neural audio codec based on Mimi."""

import torch
import torch.nn as nn

from typing import Optional, Dict, List, Union

try:
    from .configuration_mimi import MimiConfig
    from .configuration_text_sync_mimi import TextSyncMimiConfig
    from .modeling_mimi_clean import MimiPreTrainedModel, MimiModel
    from .modeling_backbone_components import (
        CrossAttentionTransformer,
        CausalAttentionTransformer,
    )
except ImportError:
    from configuration_mimi import MimiConfig
    from configuration_text_sync_mimi import TextSyncMimiConfig
    from modeling_mimi_clean import MimiPreTrainedModel, MimiModel
    from modeling_backbone_components import (
        CrossAttentionTransformer,
        CausalAttentionTransformer,
    )


class TextSyncMimi(MimiPreTrainedModel):
    """
    TextSyncMimi: Text-Synchronous Neural Audio Codec Model

    A neural audio codec model that combines text and speech representations for
    high-quality text-to-speech synthesis. Features:

    - Learnable text embeddings
    - Cross-attention transformer for text-speech alignment
    - Autoregressive transformer for causal speech generation
    - BCE-based end token prediction for dynamic duration control

    Architecture:
    - Text Embedding Layer: Maps token IDs to 4,096-dim embeddings
    - Mimi Encoder: Pre-trained audio encoder (frozen)
    - Text Projection: Linear projection from 4,096 to 512 dimensions
    - Cross-Attention Transformer: Aligns text with speech features
    - Autoregressive Transformer: Generates speech representations
    - End Token Classifier: Predicts when to stop generating
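
    Example (illustrative sketch only; the checkpoint id, shapes, and dummy inputs below
    are assumptions for demonstration, not requirements of this class):

        >>> import torch
        >>> model = TextSyncMimi(model_id="kyutai/mimi")
        >>> text_token_ids = torch.randint(0, 128256, (1, 8))       # (B, L)
        >>> input_values = torch.randn(1, 1, 24000 * 2)             # 2 s of 24 kHz audio
        >>> alignment_chunk_sizes = torch.full((1, 8), 3)           # speech frames per text token
        >>> outputs = model(
        ...     text_token_ids=text_token_ids,
        ...     input_values=input_values,
        ...     alignment_chunk_sizes=alignment_chunk_sizes,
        ... )
        >>> outputs["loss"].backward()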
""" |
|
|
|
|
|
config_class = TextSyncMimiConfig |
|
|
|
|
|
    def __init__(
        self,
        config: Optional[Union[MimiConfig, TextSyncMimiConfig]] = None,
        model_id: Optional[str] = None,
        token: Optional[str] = None,
        alpha: Optional[float] = None,
        cross_attention_layers: Optional[int] = None,
        causal_attention_layers: Optional[int] = None,
        bce_threshold: Optional[float] = None,
        vocab_size: Optional[int] = None,
    ):
        """
        Initialize TextSyncMimi model.

        Args:
            config: Model configuration (TextSyncMimiConfig or MimiConfig)
            model_id: Mimi model ID (e.g., "kyutai/mimi"). If None, uses config.mimi_model_id
            token: Hugging Face authentication token
            alpha: Weight for the BCE end token loss. If None, uses config.alpha
            cross_attention_layers: Number of cross-attention layers. If None, uses config
            causal_attention_layers: Number of autoregressive layers. If None, uses config
            bce_threshold: Hinge threshold for the BCE end token loss; loss below this value
                is not penalized. If None, uses config.bce_threshold
            vocab_size: Text vocabulary size. If None, uses config.vocab_size
        """
        if config is None:
            if model_id is None:
                raise ValueError("Either config or model_id must be provided")
            config = MimiConfig.from_pretrained(model_id, token=token)

        super().__init__(config)

        if hasattr(config, 'mimi_model_id'):
            model_id = model_id or config.mimi_model_id
        if model_id is None:
            raise ValueError("model_id must be provided either as argument or in config.mimi_model_id")

        alpha = alpha if alpha is not None else getattr(config, 'alpha', 1.0)
        cross_attention_layers = cross_attention_layers if cross_attention_layers is not None else getattr(config, 'cross_attention_layers', 2)
        causal_attention_layers = causal_attention_layers if causal_attention_layers is not None else getattr(config, 'causal_attention_layers', 2)
        bce_threshold = bce_threshold if bce_threshold is not None else getattr(config, 'bce_threshold', 0.1)
        vocab_size = vocab_size if vocab_size is not None else getattr(config, 'vocab_size', 128256)

        self.config = config
        model = MimiModel.from_pretrained(model_id, token=token)

        self.alpha = alpha
        self.bce_threshold = bce_threshold

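        # Text embeddings are kept at 4096 dims so they can be loaded directly from LLaMA
        # input embeddings (see initialize_text_embeddings_from_llama) and are projected
        # down to the 512-dim speech space by text_proj.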
        self.text_token_embedding = nn.Embedding(vocab_size, 4096)

        self.text_proj = nn.Linear(4096, 512)

        cross_attention_config = MimiConfig(**self.config.__dict__)
        cross_attention_config.num_hidden_layers = cross_attention_layers
        cross_attention_config.hidden_size = 512
        self.cross_attention_transformer = CrossAttentionTransformer(cross_attention_config)

        causal_attention_config = MimiConfig(**self.config.__dict__)
        causal_attention_config.num_hidden_layers = causal_attention_layers
        causal_attention_config.hidden_size = 512
        self.ar_transformer = CausalAttentionTransformer(causal_attention_config)

        self.text_speech_latent_embed = nn.Embedding(1, 512)
        self.time_speech_start_embed = nn.Embedding(1, 512)
        self.time_speech_end_embed = nn.Embedding(1, 512)

        self.end_token_classifier = nn.Linear(512, 1)

        self.post_init()

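        # Reuse the pretrained Mimi components directly from the loaded checkpoint. They are
        # attached after post_init(), presumably so that weight initialization of the new
        # TextSyncMimi modules does not overwrite the pretrained parameters.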
        self.encoder = model.encoder
        self.encoder_transformer = model.encoder_transformer
        self.quantizer = model.quantizer
        self.downsample = model.downsample
        self.upsample = model.upsample

        self._print_subnetwork_parameter_counts()

    def initialize_text_embeddings_from_weights(self, embedding_weight: torch.Tensor) -> None:
        """
        Initialize text embeddings from a weight matrix.

        Args:
            embedding_weight: Weight matrix of shape (vocab_size, 4096)
        """
        if embedding_weight.dim() != 2 or embedding_weight.size(1) != 4096:
            raise ValueError("embedding_weight must have shape (vocab_size, 4096)")
        if embedding_weight.size(0) != self.text_token_embedding.num_embeddings:
            raise ValueError("Provided vocab_size does not match model's text_token_embedding")
        with torch.no_grad():
            self.text_token_embedding.weight.copy_(embedding_weight)
        for p in self.text_token_embedding.parameters():
            p.requires_grad = True

    def initialize_text_embeddings_from_llama(self, llama_embeddings_module: torch.nn.Module) -> None:
        """
        Initialize text embeddings from a LLaMA embedding module.

        Args:
            llama_embeddings_module: LLaMA embedding module with weight shape (vocab_size, 4096)
        """
        if not hasattr(llama_embeddings_module, 'weight'):
            raise ValueError("llama_embeddings_module must have a 'weight' attribute")
        weight = llama_embeddings_module.weight.data
        self.initialize_text_embeddings_from_weights(weight)

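    # Usage sketch (hedged): how the text embeddings might be seeded from a LLaMA checkpoint.
    # The AutoModelForCausalLM import and the model id below are illustrative assumptions,
    # not part of this module.
    #
    #   from transformers import AutoModelForCausalLM
    #   llama = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B")
    #   model.initialize_text_embeddings_from_llama(llama.get_input_embeddings())
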
    def _print_subnetwork_parameter_counts(self) -> None:
        """Print parameter counts for model subnetworks."""
        print("=" * 70)
        print("TextSyncMimi Parameter Counts")
        print("=" * 70)
        print(f"Encoder: {sum(p.numel() for p in self.encoder.parameters()) / 1e6:.2f}M")
        print(f"Encoder Transformer: {sum(p.numel() for p in self.encoder_transformer.parameters()) / 1e6:.2f}M")
        print(f"Cross-Attention Transformer: {sum(p.numel() for p in self.cross_attention_transformer.parameters()) / 1e6:.2f}M")
        print(f"AR Transformer: {sum(p.numel() for p in self.ar_transformer.parameters()) / 1e6:.2f}M")
        print(f"Quantizer: {sum(p.numel() for p in self.quantizer.parameters()) / 1e6:.2f}M")
        print("=" * 70)

    def encode_audio_to_representation(
        self,
        input_values: torch.Tensor,
        audio_attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Encode audio to speech representation.

        Args:
            input_values: Audio waveform (B, 1, audio_len)
            audio_attention_mask: Attention mask (B, audio_len)

        Returns:
            Speech embeddings (B, 512, T_frames), where T_frames is roughly 12.5 frames
            per second of 24 kHz audio
        """
        batch_size = input_values.shape[0]
        device = input_values.device

        embeddings = self.encoder(input_values)
        encoder_outputs = self.encoder_transformer(embeddings.transpose(1, 2))
        embeddings = encoder_outputs[0].transpose(1, 2)
        embeddings = self.downsample(embeddings)

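        # Convert the sample-level audio mask to a frame-level mask: the encoder produces
        # latent frames at 12.5 Hz from 24 kHz input, i.e. one frame per 24000 / 12.5 = 1920
        # samples, and padded frames are zeroed out.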
        if audio_attention_mask is not None:
            speech_seq_len = embeddings.shape[-1]
            speech_attention_mask = torch.zeros(batch_size, speech_seq_len, device=device, dtype=torch.bool)

            for b in range(batch_size):
                actual_audio_len = audio_attention_mask[b].sum().item()
                actual_speech_len = int(actual_audio_len * 12.5 / 24000)
                actual_speech_len = min(actual_speech_len, speech_seq_len)
                if actual_speech_len > 0:
                    speech_attention_mask[b, :actual_speech_len] = True

            speech_mask_expanded = speech_attention_mask.unsqueeze(1)
            embeddings = embeddings * speech_mask_expanded.float()

        return embeddings

    def generate_autoregressive(
        self,
        text_token_ids: torch.LongTensor,
        input_values: Optional[torch.Tensor] = None,
        speech_embeddings: Optional[torch.Tensor] = None,
        audio_attention_mask: Optional[torch.Tensor] = None,
        speech_attention_mask: Optional[torch.Tensor] = None,
        text_attention_mask: Optional[torch.Tensor] = None,
        max_z_tokens: int = 50,
        end_token_threshold: float = 0.5,
        device: Optional[torch.device] = None,
    ) -> List[List[torch.Tensor]]:
        """
        Generate audio autoregressively.

        Args:
            text_token_ids: Text token IDs (B, L)
            input_values: Audio input (B, 1, 24000 * T) - for normal mode
            speech_embeddings: Pre-computed speech embeddings (B, T, 512) - for cached mode
            audio_attention_mask: Audio mask (B, audio_seq_len) - for normal mode
            speech_attention_mask: Speech mask (B, speech_seq_len) - for cached mode
            text_attention_mask: Text mask (B, text_seq_len)
            max_z_tokens: Maximum z tokens per text position
            end_token_threshold: Probability threshold for stopping
            device: Device for computation

        Returns:
            List of z_tokens lists (one per batch item)
        """
        if device is None:
            device = text_token_ids.device

        self.eval()

        with torch.no_grad():
            if speech_embeddings is None:
                if input_values is None:
                    raise ValueError("Either input_values or speech_embeddings must be provided")
                speech_embeddings = self.encode_audio_to_representation(
                    input_values,
                    audio_attention_mask=audio_attention_mask
                )
                speech_embeddings = speech_embeddings.transpose(1, 2)

            text_embeddings_4096 = self.text_token_embedding(text_token_ids)
            text_embeddings_proj = self.text_proj(text_embeddings_4096)

            formatted_text_attention_mask = None
            formatted_speech_attention_mask = None

            batch_size, text_seq_len = text_embeddings_proj.shape[:2]

            if text_attention_mask is not None:
                causal_mask = torch.tril(torch.ones(text_seq_len, text_seq_len, device=device, dtype=text_embeddings_proj.dtype))
                causal_mask = causal_mask.view(1, 1, text_seq_len, text_seq_len).expand(batch_size, -1, -1, -1)
                padding_mask = text_attention_mask.view(batch_size, 1, 1, text_seq_len)
                combined_mask = causal_mask * padding_mask
                formatted_text_attention_mask = torch.where(combined_mask.bool(), 0.0, float('-inf'))
            else:
                causal_mask = torch.tril(torch.ones(text_seq_len, text_seq_len, device=device, dtype=text_embeddings_proj.dtype))
                causal_mask = causal_mask.view(1, 1, text_seq_len, text_seq_len).expand(batch_size, -1, -1, -1)
                formatted_text_attention_mask = torch.where(causal_mask.bool(), 0.0, float('-inf'))

            if speech_attention_mask is not None:
                speech_seq_len = speech_embeddings.shape[1]
                speech_mask = speech_attention_mask.bool()
                formatted_speech_attention_mask = speech_mask.view(batch_size, 1, 1, speech_seq_len)
                formatted_speech_attention_mask = torch.where(formatted_speech_attention_mask, 0.0, float('-inf'))
            elif audio_attention_mask is not None:
                speech_seq_len = speech_embeddings.shape[1]
                speech_mask = torch.zeros(batch_size, speech_seq_len, dtype=torch.bool, device=device)
                for b in range(batch_size):
                    audio_len = audio_attention_mask[b].sum().item()
                    speech_len = int(audio_len * 12.5 / 24000)
                    speech_len = min(speech_len, speech_seq_len)
                    speech_mask[b, :speech_len] = True
                formatted_speech_attention_mask = speech_mask.view(batch_size, 1, 1, speech_seq_len)
                formatted_speech_attention_mask = torch.where(formatted_speech_attention_mask, 0.0, float('-inf'))
            else:
                formatted_speech_attention_mask = None

            cross_attention_outputs = self.cross_attention_transformer(
                hidden_states=text_embeddings_proj,
                encoder_hidden_states=speech_embeddings,
                attention_mask=formatted_text_attention_mask,
                encoder_attention_mask=formatted_speech_attention_mask,
                alignment_chunk_sizes=None,
            )
            cross_attention_outputs = cross_attention_outputs.last_hidden_state

            text_speech_latent_emb = self.text_speech_latent_embed(torch.zeros(1, dtype=torch.long, device=device))
            time_speech_start_emb = self.time_speech_start_embed(torch.zeros(1, dtype=torch.long, device=device))
            time_speech_end_emb = self.time_speech_end_embed(torch.zeros(1, dtype=torch.long, device=device))

            generated_z_tokens = []

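            # Decode each batch item separately. The autoregressive input is an interleaved
            # sequence: [latent] followed, for each text position i, by
            # [t_i, s_i, <time_start>, z_1 ... z_k, <time_end>], where each z is the
            # transformer's own last prediction fed back in until the end-token classifier
            # fires or max_z_tokens is reached.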
            for b in range(batch_size):
                if text_attention_mask is not None:
                    valid_text_len = text_attention_mask[b].sum().item()
                else:
                    valid_text_len = text_embeddings_proj.shape[1]

                sequence = [text_speech_latent_emb]
                batch_z_tokens = []

                for i in range(valid_text_len):
                    t_i = text_embeddings_proj[b, i:i+1]
                    s_i = cross_attention_outputs[b, i:i+1]
                    sequence.extend([t_i, s_i])

                    sequence.append(time_speech_start_emb)

                    z_count = 0
                    while z_count < max_z_tokens:
                        current_sequence = torch.cat(sequence, dim=0).unsqueeze(0)

                        seq_len = current_sequence.shape[1]
                        ar_attention_mask = torch.ones(1, seq_len, dtype=torch.bool, device=device)

                        ar_outputs = self.ar_transformer(
                            hidden_states=current_sequence,
                            attention_mask=ar_attention_mask,
                        )

                        last_prediction = ar_outputs.last_hidden_state[0, -1:, :]

                        end_token_logit = self.end_token_classifier(last_prediction).squeeze(-1)
                        end_token_prob = torch.sigmoid(end_token_logit).item()

                        if end_token_prob >= end_token_threshold:
                            break
                        else:
                            sequence.append(last_prediction)
                            batch_z_tokens.append(last_prediction.squeeze(0))
                            z_count += 1

                    sequence.append(time_speech_end_emb)

                generated_z_tokens.append(batch_z_tokens)

        return generated_z_tokens

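    # Inference sketch (hedged): shapes follow the docstring above; the tokenized text and
    # reference audio are assumed to be prepared by the caller.
    #
    #   z_tokens_per_item = model.generate_autoregressive(
    #       text_token_ids=text_token_ids,      # (B, L)
    #       input_values=input_values,          # (B, 1, 24000 * T) reference audio
    #       max_z_tokens=50,
    #       end_token_threshold=0.5,
    #   )
    #   # z_tokens_per_item[b] is a flat list of 512-dim generated latent frames for item b.
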
    def forward(
        self,
        text_token_ids: torch.LongTensor,
        input_values: Optional[torch.Tensor] = None,
        speech_embeddings: Optional[torch.Tensor] = None,
        alignment_chunk_sizes: Optional[torch.Tensor] = None,
        audio_attention_mask: Optional[torch.Tensor] = None,
        speech_attention_mask: Optional[torch.Tensor] = None,
        text_attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Dict[str, torch.Tensor]:
        """
        Forward pass for training.

        Args:
            text_token_ids: Text token IDs (B, L)
            input_values: Audio input (B, 1, 24000 * T) - for normal mode
            speech_embeddings: Pre-computed speech embeddings (B, T, 512) - for cached mode
            alignment_chunk_sizes: Alignment chunk sizes (B, L)
            audio_attention_mask: Audio mask (B, audio_seq_len)
            speech_attention_mask: Speech mask (B, speech_seq_len)
            text_attention_mask: Text mask (B, text_seq_len)

        Returns:
            Dictionary with 'loss', 'reconstruction_loss', and 'bce_end_token_loss'
        """
        if speech_embeddings is None:
            if input_values is None:
                raise ValueError("Either input_values or speech_embeddings must be provided")
            speech_embeddings_raw = self.encode_audio_to_representation(
                input_values,
                audio_attention_mask
            )
            speech_embeddings = speech_embeddings_raw.transpose(1, 2)

        text_embeddings_4096 = self.text_token_embedding(text_token_ids)
        text_embeddings = self.text_proj(text_embeddings_4096)

        formatted_text_attention_mask = None
        formatted_speech_attention_mask = None

        batch_size, text_seq_len = text_embeddings.shape[:2]

        if text_attention_mask is not None:
            causal_mask = torch.tril(torch.ones(text_seq_len, text_seq_len, device=text_embeddings.device, dtype=text_embeddings.dtype))
            causal_mask = causal_mask.view(1, 1, text_seq_len, text_seq_len).expand(batch_size, -1, -1, -1)

            padding_mask = text_attention_mask.view(batch_size, 1, 1, text_seq_len)
            combined_mask = causal_mask * padding_mask

            formatted_text_attention_mask = torch.where(combined_mask.bool(), 0.0, float('-inf'))
        else:
            causal_mask = torch.tril(torch.ones(text_seq_len, text_seq_len, device=text_embeddings.device, dtype=text_embeddings.dtype))
            causal_mask = causal_mask.view(1, 1, text_seq_len, text_seq_len).expand(batch_size, -1, -1, -1)
            formatted_text_attention_mask = torch.where(causal_mask.bool(), 0.0, float('-inf'))

        if speech_attention_mask is not None:
            speech_seq_len = speech_embeddings.shape[1]
            speech_mask = speech_attention_mask.bool()

            formatted_speech_attention_mask = speech_mask.view(batch_size, 1, 1, speech_seq_len)
            formatted_speech_attention_mask = torch.where(formatted_speech_attention_mask, 0.0, float('-inf'))
        elif audio_attention_mask is not None:
            speech_seq_len = speech_embeddings.shape[1]

            speech_mask = torch.zeros(batch_size, speech_seq_len, dtype=torch.bool, device=speech_embeddings.device)
            for b in range(batch_size):
                audio_len = audio_attention_mask[b].sum().item()
                speech_len = int(audio_len * 12.5 / 24000)
                speech_len = min(speech_len, speech_seq_len)
                speech_mask[b, :speech_len] = True

            formatted_speech_attention_mask = speech_mask.view(batch_size, 1, 1, speech_seq_len)
            formatted_speech_attention_mask = torch.where(formatted_speech_attention_mask, 0.0, float('-inf'))
        else:
            formatted_speech_attention_mask = None

        cross_attention_outputs = self.cross_attention_transformer(
            hidden_states=text_embeddings,
            encoder_hidden_states=speech_embeddings,
            attention_mask=formatted_text_attention_mask,
            encoder_attention_mask=formatted_speech_attention_mask,
            alignment_chunk_sizes=None,
        )
        cross_attention_outputs = cross_attention_outputs.last_hidden_state

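        # Round-trip the speech embeddings through the Mimi quantizer (encode to codes, then
        # decode back) under no_grad to obtain the regression targets: the AR transformer is
        # trained to predict representations that live in the quantizer's decoded space.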
        with torch.no_grad():
            embeddings_bct = speech_embeddings.transpose(1, 2)
            codes_kbt = self.quantizer.encode(embeddings_bct)
            codes_bkt = codes_kbt.transpose(0, 1)
            decoder_input_emb = self.quantizer.decode(codes_bkt)
            target_representation = decoder_input_emb.transpose(1, 2)

        device = text_embeddings.device
        text_speech_latent_emb = self.text_speech_latent_embed(torch.zeros(1, dtype=torch.long, device=device))
        time_speech_start_emb = self.time_speech_start_embed(torch.zeros(1, dtype=torch.long, device=device))
        time_speech_end_emb = self.time_speech_end_embed(torch.zeros(1, dtype=torch.long, device=device))

        batch_size = text_embeddings.shape[0]
        interleaved_sequences = []
        loss_masks = []
        bce_labels_batch = []
        bce_masks = []
        sequence_lengths = []
        all_z_tokens = []
        max_total_length = 0

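        # Build one interleaved training sequence per batch item:
        #   [latent] + for each text position i: [t_i, s_i, <time_start>, z_1 ... z_{c_i}, <time_end>]
        # with c_i = alignment_chunk_sizes[b, i]. loss_mask marks the z positions (MSE targets),
        # the BCE label is 1 only at <time_end>, and bce_mask marks the positions scored by the
        # end-token classifier (z positions and <time_end>).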
        for b in range(batch_size):
            sequence_parts = [text_speech_latent_emb]
            loss_mask_parts = [False]
            bce_label_parts = [0]
            bce_mask_parts = [False]

            if text_attention_mask is not None:
                valid_text_len = text_attention_mask[b].sum().item()
            else:
                valid_text_len = text_embeddings.shape[1]

            speech_position = 0

            for i in range(valid_text_len):
                t_i = text_embeddings[b, i:i+1]
                sequence_parts.append(t_i)
                loss_mask_parts.append(False)
                bce_label_parts.append(0)
                bce_mask_parts.append(False)

                s_i = cross_attention_outputs[b, i:i+1]
                sequence_parts.append(s_i)
                loss_mask_parts.append(False)
                bce_label_parts.append(0)
                bce_mask_parts.append(False)

                sequence_parts.append(time_speech_start_emb)
                loss_mask_parts.append(False)
                bce_label_parts.append(0)
                bce_mask_parts.append(False)

                chunk_size = alignment_chunk_sizes[b, i].item()
                if chunk_size > 0:
                    end_position = speech_position + chunk_size
                    end_position = min(end_position, target_representation.shape[1])
                    actual_chunk_size = end_position - speech_position

                    if actual_chunk_size > 0:
                        z_tokens = target_representation[b, speech_position:end_position]
                        sequence_parts.append(z_tokens)
                        loss_mask_parts.extend([True] * actual_chunk_size)
                        bce_label_parts.extend([0] * actual_chunk_size)
                        bce_mask_parts.extend([True] * actual_chunk_size)

                        all_z_tokens.append(z_tokens)

                    speech_position = end_position

                sequence_parts.append(time_speech_end_emb)
                loss_mask_parts.append(False)
                bce_label_parts.append(1)
                bce_mask_parts.append(True)

            full_sequence = torch.cat(sequence_parts, dim=0)
            loss_mask = torch.tensor(loss_mask_parts, dtype=torch.bool, device=device)
            bce_labels = torch.tensor(bce_label_parts, dtype=torch.float, device=device)
            bce_mask = torch.tensor(bce_mask_parts, dtype=torch.bool, device=device)

            interleaved_sequences.append(full_sequence)
            loss_masks.append(loss_mask)
            bce_labels_batch.append(bce_labels)
            bce_masks.append(bce_mask)
            sequence_lengths.append(full_sequence.shape[0])
            max_total_length = max(max_total_length, full_sequence.shape[0])

        padded_sequences = []
        padded_loss_masks = []
        padded_bce_labels = []
        padded_bce_masks = []

        for sequence, loss_mask, bce_labels, bce_mask in zip(interleaved_sequences, loss_masks, bce_labels_batch, bce_masks):
            current_length = sequence.shape[0]
            if current_length < max_total_length:
                padding = torch.zeros(max_total_length - current_length, 512, device=device, dtype=sequence.dtype)
                padded_sequence = torch.cat([sequence, padding], dim=0)

                mask_padding = torch.zeros(max_total_length - current_length, dtype=torch.bool, device=device)
                padded_mask = torch.cat([loss_mask, mask_padding], dim=0)

                bce_label_padding = torch.zeros(max_total_length - current_length, dtype=torch.float, device=device)
                padded_bce_label = torch.cat([bce_labels, bce_label_padding], dim=0)

                bce_mask_padding = torch.zeros(max_total_length - current_length, dtype=torch.bool, device=device)
                padded_bce_mask = torch.cat([bce_mask, bce_mask_padding], dim=0)
            else:
                padded_sequence = sequence
                padded_mask = loss_mask
                padded_bce_label = bce_labels
                padded_bce_mask = bce_mask

            padded_sequences.append(padded_sequence)
            padded_loss_masks.append(padded_mask)
            padded_bce_labels.append(padded_bce_label)
            padded_bce_masks.append(padded_bce_mask)

        interleaved_batch = torch.stack(padded_sequences, dim=0)
        loss_mask_batch = torch.stack(padded_loss_masks, dim=0)
        bce_labels_batch_tensor = torch.stack(padded_bce_labels, dim=0)
        bce_mask_batch = torch.stack(padded_bce_masks, dim=0)

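        # Teacher-forced next-step prediction: the AR transformer reads positions [0, N-2] and
        # is trained against positions [1, N-1], so the MSE and BCE masks/labels are shifted by
        # one to align with the prediction targets.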
        if max_total_length > 1:
            ar_input = interleaved_batch[:, :-1, :]
            ar_targets = interleaved_batch[:, 1:, :]
            ar_loss_mask = loss_mask_batch[:, 1:]
            ar_bce_labels = bce_labels_batch_tensor[:, 1:]
            ar_bce_mask = bce_mask_batch[:, 1:]

            ar_seq_len = ar_input.shape[1]
            ar_attention_mask = torch.zeros(batch_size, ar_seq_len, dtype=torch.bool, device=device)
            for b in range(batch_size):
                valid_len = min(ar_seq_len, sequence_lengths[b] - 1)
                if valid_len > 0:
                    ar_attention_mask[b, :valid_len] = True

            ar_outputs = self.ar_transformer(
                hidden_states=ar_input,
                attention_mask=ar_attention_mask,
            )
            ar_predictions = ar_outputs.last_hidden_state

            bce_logits = self.end_token_classifier(ar_predictions).squeeze(-1)

            if ar_loss_mask.any():
                valid_predictions = ar_predictions[ar_loss_mask]
                valid_targets = ar_targets[ar_loss_mask]

                reconstruction_loss = nn.functional.mse_loss(
                    valid_predictions,
                    valid_targets,
                    reduction='mean'
                )
            else:
                reconstruction_loss = torch.tensor(0.0, device=device, requires_grad=True)

            if ar_bce_mask.any():
                valid_bce_logits = bce_logits[ar_bce_mask]
                valid_bce_labels = ar_bce_labels[ar_bce_mask]

                bce_end_token_loss = nn.functional.binary_cross_entropy_with_logits(
                    valid_bce_logits,
                    valid_bce_labels,
                    reduction='mean'
                )
            else:
                bce_end_token_loss = torch.tensor(0.0, device=device, requires_grad=True)

            if self.bce_threshold > 0.0:
                clamped_bce_loss = torch.clamp(bce_end_token_loss - self.bce_threshold, min=0.0)
                total_loss = reconstruction_loss + self.alpha * clamped_bce_loss
            else:
                total_loss = reconstruction_loss + self.alpha * bce_end_token_loss
        else:
            reconstruction_loss = torch.tensor(0.0, device=device, requires_grad=True)
            bce_end_token_loss = torch.tensor(0.0, device=device, requires_grad=True)
            total_loss = reconstruction_loss + self.alpha * bce_end_token_loss

        return {
            'loss': total_loss,
            'reconstruction_loss': reconstruction_loss,
            'bce_end_token_loss': bce_end_token_loss,
        }


__all__ = ["TextSyncMimi"] |
|
|
|
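

# Minimal training-step sketch (hedged): tensor names and shapes follow the docstrings above;
# the data pipeline, batch contents, and optimizer choice are illustrative assumptions.
#
#   model = TextSyncMimi(model_id="kyutai/mimi")
#   optimizer = torch.optim.AdamW((p for p in model.parameters() if p.requires_grad), lr=1e-4)
#   outputs = model(
#       text_token_ids=batch["text_token_ids"],                  # (B, L)
#       input_values=batch["input_values"],                      # (B, 1, 24000 * T)
#       alignment_chunk_sizes=batch["alignment_chunk_sizes"],    # (B, L)
#       audio_attention_mask=batch["audio_attention_mask"],      # (B, 24000 * T)
#       text_attention_mask=batch["text_attention_mask"],        # (B, L)
#   )
#   outputs["loss"].backward()
#   optimizer.step()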