"""PyTorch TextSyncMimi model - Text-synchronous neural audio codec based on Mimi."""
import torch
import torch.nn as nn
from typing import Optional, Dict, List, Union
try:
from .configuration_mimi import MimiConfig
from .configuration_text_sync_mimi import TextSyncMimiConfig
from .modeling_mimi_clean import MimiPreTrainedModel, MimiModel
from .modeling_backbone_components import (
CrossAttentionTransformer,
CausalAttentionTransformer
)
except ImportError:
from configuration_mimi import MimiConfig
from configuration_text_sync_mimi import TextSyncMimiConfig
from modeling_mimi_clean import MimiPreTrainedModel, MimiModel
from modeling_backbone_components import (
CrossAttentionTransformer,
CausalAttentionTransformer
)
class TextSyncMimi(MimiPreTrainedModel):
"""
TextSyncMimi: Text-Synchronous Neural Audio Codec Model
A neural audio codec model that combines text and speech representations for
high-quality text-to-speech synthesis. Features:
- Learnable text embeddings
- Cross-attention transformer for text-speech alignment
- Autoregressive transformer for causal speech generation
- BCE-based end token prediction for dynamic duration control
Architecture:
- Text Embedding Layer: Maps token IDs to 4,096-dim embeddings
- Mimi Encoder: Pre-trained audio encoder (frozen)
- Text Projection: Linear projection from 4,096 to 512 dimensions
- Cross-Attention Transformer: Aligns text with speech features
- Autoregressive Transformer: Generates speech representations
- End Token Classifier: Predicts when to stop generating
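Example (illustrative sketch; assumes the config provides `mimi_model_id`, and
`token_ids`, `waveforms`, and `chunk_sizes` are placeholder tensors prepared by the caller):
>>> config = TextSyncMimiConfig.from_pretrained("path/to/checkpoint")  # hypothetical path
>>> model = TextSyncMimi(config)
>>> outputs = model(text_token_ids=token_ids, input_values=waveforms,
...                 alignment_chunk_sizes=chunk_sizes)
>>> outputs['loss']  # total training loss (reconstruction + weighted BCE end-token loss)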
"""
config_class = TextSyncMimiConfig
def __init__(
self,
config: Optional[Union[MimiConfig, 'TextSyncMimiConfig']] = None,
model_id: Optional[str] = None,
token: Optional[str] = None,
alpha: Optional[float] = None,
cross_attention_layers: Optional[int] = None,
causal_attention_layers: Optional[int] = None,
bce_threshold: Optional[float] = None,
vocab_size: Optional[int] = None,
):
"""
Initialize TextSyncMimi model.
Args:
config: Model configuration (TextSyncMimiConfig or MimiConfig)
model_id: Mimi model ID (e.g., "kyutai/mimi"). If None, uses config.mimi_model_id
token: Hugging Face authentication token
alpha: Weight for BCE end token loss. If None, uses config.alpha
cross_attention_layers: Number of cross-attention layers. If None, uses config
causal_attention_layers: Number of autoregressive layers. If None, uses config
bce_threshold: Margin below which the BCE end-token loss is clamped to zero. If None, uses config.bce_threshold
vocab_size: Text vocabulary size. If None, uses config.vocab_size
"""
# Handle config initialization for both manual instantiation and from_pretrained
if config is None:
if model_id is None:
raise ValueError("Either config or model_id must be provided")
config = MimiConfig.from_pretrained(model_id, token=token)
super().__init__(config)
# Extract parameters from config if not explicitly provided
if hasattr(config, 'mimi_model_id'):
model_id = model_id or config.mimi_model_id
if model_id is None:
raise ValueError("model_id must be provided either as argument or in config.mimi_model_id")
alpha = alpha if alpha is not None else getattr(config, 'alpha', 1.0)
cross_attention_layers = cross_attention_layers if cross_attention_layers is not None else getattr(config, 'cross_attention_layers', 2)
causal_attention_layers = causal_attention_layers if causal_attention_layers is not None else getattr(config, 'causal_attention_layers', 2)
bce_threshold = bce_threshold if bce_threshold is not None else getattr(config, 'bce_threshold', 0.1)
vocab_size = vocab_size if vocab_size is not None else getattr(config, 'vocab_size', 128256)
# load the mimi backbone
self.config = config
model = MimiModel.from_pretrained(model_id, token=token)
# hyperparameters for auxiliary loss
self.alpha = alpha
self.bce_threshold = bce_threshold
# Learnable text token embedding
self.text_token_embedding = nn.Embedding(vocab_size, 4096)
# Text projection
self.text_proj = nn.Linear(4096, 512)
# Cross-attention transformer
cross_attention_config = MimiConfig(**self.config.__dict__)
cross_attention_config.num_hidden_layers = cross_attention_layers
cross_attention_config.hidden_size = 512
self.cross_attention_transformer = CrossAttentionTransformer(cross_attention_config)
# decoder part (v1)
# Auto-regressive decoder:
# <|text_speech_latent|> [t_i] [s_i] <|time_speech_start|> [z_(i,1)] [z_(i,2)] ... [z_(i,K)] <|time_speech_end|>
# loss is masked (not computed) for the <|text_speech_latent|>, [t_i], [s_i], and <|time_speech_start|> positions
# t_i already mapped from 4096 (e.g., llama embedding) -> 512
# s_i already 512
# z is mimi's decoder-input which is also 512
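# Illustrative layout (example only; the per-token chunk sizes come from alignment_chunk_sizes):
#   <|text_speech_latent|> t_1 s_1 <|time_speech_start|> z_(1,1) z_(1,2) z_(1,3) <|time_speech_end|>
#                          t_2 s_2 <|time_speech_start|> z_(2,1) z_(2,2) <|time_speech_end|> ...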
causal_attention_config = MimiConfig(**self.config.__dict__)
causal_attention_config.num_hidden_layers = causal_attention_layers
causal_attention_config.hidden_size = 512
self.ar_transformer = CausalAttentionTransformer(causal_attention_config)
# embedding for special positions in the autoregressive decoder
self.text_speech_latent_embed = nn.Embedding(1, 512)
self.time_speech_start_embed = nn.Embedding(1, 512)
self.time_speech_end_embed = nn.Embedding(1, 512)
# Binary classification head for end token prediction
self.end_token_classifier = nn.Linear(512, 1)
self.post_init()
# Frozen Mimi components
self.encoder = model.encoder
self.encoder_transformer = model.encoder_transformer
self.quantizer = model.quantizer
self.downsample = model.downsample
self.upsample = model.upsample
# Print the number of parameters of each sub-network, in millions
self._print_subnetwork_parameter_counts()
def initialize_text_embeddings_from_weights(self, embedding_weight: torch.Tensor) -> None:
"""
Initialize text embeddings from a weight matrix.
Args:
embedding_weight: Weight matrix of shape (vocab_size, 4096)
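Example (illustrative; random weights just to show the expected shape, assuming the
default 128,256-token vocabulary):
>>> model.initialize_text_embeddings_from_weights(torch.randn(128256, 4096))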
"""
if embedding_weight.dim() != 2 or embedding_weight.size(1) != 4096:
raise ValueError("embedding_weight must have shape (vocab_size, 4096)")
if embedding_weight.size(0) != self.text_token_embedding.num_embeddings:
raise ValueError("Provided vocab_size does not match model's text_token_embedding")
with torch.no_grad():
self.text_token_embedding.weight.copy_(embedding_weight)
for p in self.text_token_embedding.parameters():
p.requires_grad = True
def initialize_text_embeddings_from_llama(self, llama_embeddings_module: torch.nn.Module) -> None:
"""
Initialize text embeddings from a LLaMA embedding module.
Args:
llama_embeddings_module: LLaMA embedding module with weight shape (vocab_size, 4096)
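Example (illustrative; assumes a loaded Hugging Face LLaMA-style model `llama`):
>>> model.initialize_text_embeddings_from_llama(llama.get_input_embeddings())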
"""
if not hasattr(llama_embeddings_module, 'weight'):
raise ValueError("llama_embeddings_module must have a 'weight' attribute")
weight = llama_embeddings_module.weight.data
self.initialize_text_embeddings_from_weights(weight)
def _print_subnetwork_parameter_counts(self) -> None:
"""Print parameter counts for model subnetworks."""
print("=" * 70)
print("TextSyncMimi Parameter Counts")
print("=" * 70)
print(f"Encoder: {sum(p.numel() for p in self.encoder.parameters()) / 1e6:.2f}M")
print(f"Encoder Transformer: {sum(p.numel() for p in self.encoder_transformer.parameters()) / 1e6:.2f}M")
print(f"Cross-Attention Transformer: {sum(p.numel() for p in self.cross_attention_transformer.parameters()) / 1e6:.2f}M")
print(f"AR Transformer: {sum(p.numel() for p in self.ar_transformer.parameters()) / 1e6:.2f}M")
print(f"Quantizer: {sum(p.numel() for p in self.quantizer.parameters()) / 1e6:.2f}M")
print("=" * 70)
def encode_audio_to_representation(
self,
input_values: torch.Tensor,
audio_attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Encode audio to speech representation.
Args:
input_values: Audio waveform (B, 1, audio_len)
audio_attention_mask: Attention mask (B, audio_len)
Returns:
Speech embeddings (B, 512, 12.5 * T)
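Note: 24 kHz audio maps to roughly 12.5 frames per second after downsampling, so a
2-second clip (48,000 samples) yields about 25 frames, i.e. an output of shape (B, 512, 25).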
"""
batch_size = input_values.shape[0]
device = input_values.device
# Encode through Mimi encoder pipeline
embeddings = self.encoder(input_values)
encoder_outputs = self.encoder_transformer(embeddings.transpose(1, 2))
embeddings = encoder_outputs[0].transpose(1, 2)
embeddings = self.downsample(embeddings)
# Apply attention mask if provided
if audio_attention_mask is not None:
speech_seq_len = embeddings.shape[-1]
speech_attention_mask = torch.zeros(batch_size, speech_seq_len, device=device, dtype=torch.bool)
for b in range(batch_size):
actual_audio_len = audio_attention_mask[b].sum().item()
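# Convert the valid audio length (samples at 24 kHz) into the valid number of speech frames (~12.5 frames/s)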
actual_speech_len = int(actual_audio_len * 12.5 / 24000)
actual_speech_len = min(actual_speech_len, speech_seq_len)
if actual_speech_len > 0:
speech_attention_mask[b, :actual_speech_len] = True
speech_mask_expanded = speech_attention_mask.unsqueeze(1)
embeddings = embeddings * speech_mask_expanded.float()
return embeddings
def generate_autoregressive(
self,
text_token_ids: torch.LongTensor,
input_values: Optional[torch.Tensor] = None,
speech_embeddings: Optional[torch.Tensor] = None,
audio_attention_mask: Optional[torch.Tensor] = None,
speech_attention_mask: Optional[torch.Tensor] = None,
text_attention_mask: Optional[torch.Tensor] = None,
max_z_tokens: int = 50,
end_token_threshold: float = 0.5,
device: Optional[torch.device] = None,
) -> List[List[torch.Tensor]]:
"""
Generate audio autoregressively.
Args:
text_token_ids: Text token IDs (B, L)
input_values: Audio input (B, 1, 24000 * T) - for normal mode
speech_embeddings: Pre-computed speech embeddings (B, T, 512) - for cached mode
audio_attention_mask: Audio mask (B, audio_seq_len) - for normal mode
speech_attention_mask: Speech mask (B, speech_seq_len) - for cached mode
text_attention_mask: Text mask (B, text_seq_len)
max_z_tokens: Maximum z tokens per text position
end_token_threshold: Probability threshold for stopping
device: Device for computation
Returns:
List of z_tokens lists (one per batch item)
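Example (illustrative; `token_ids` and `waveform` are placeholder tensors):
>>> z_tokens = model.generate_autoregressive(
...     text_token_ids=token_ids,   # (1, L)
...     input_values=waveform,      # (1, 1, 48000), i.e. 2 seconds at 24 kHz
... )
>>> len(z_tokens)  # one list of generated 512-dim frames per batch item
1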
"""
if device is None:
device = text_token_ids.device
self.eval()
with torch.no_grad():
# Get speech embeddings for cross-attention context
if speech_embeddings is not None:
# Use pre-computed speech embeddings (cached mode)
# speech_embeddings should already be (B, T, 512)
pass # speech_embeddings is already provided
else:
# Compute speech embeddings from input_values (normal mode)
if input_values is None:
raise ValueError("Either input_values or speech_embeddings must be provided")
speech_embeddings = self.encode_audio_to_representation(
input_values,
audio_attention_mask=audio_attention_mask
)
speech_embeddings = speech_embeddings.transpose(1, 2) # (B, T, 512)
# Embed token ids then project to 512
text_embeddings_4096 = self.text_token_embedding(text_token_ids) # (B, L, 4096)
text_embeddings_proj = self.text_proj(text_embeddings_4096) # (B, L, 512)
# Apply cross attention (same as in forward)
# Create attention masks
formatted_text_attention_mask = None
formatted_speech_attention_mask = None
batch_size, text_seq_len = text_embeddings_proj.shape[:2]
if text_attention_mask is not None:
causal_mask = torch.tril(torch.ones(text_seq_len, text_seq_len, device=device, dtype=text_embeddings_proj.dtype))
causal_mask = causal_mask.view(1, 1, text_seq_len, text_seq_len).expand(batch_size, -1, -1, -1)
padding_mask = text_attention_mask.view(batch_size, 1, 1, text_seq_len)
combined_mask = causal_mask * padding_mask
formatted_text_attention_mask = torch.where(combined_mask.bool(), 0.0, float('-inf'))
else:
causal_mask = torch.tril(torch.ones(text_seq_len, text_seq_len, device=device, dtype=text_embeddings_proj.dtype))
causal_mask = causal_mask.view(1, 1, text_seq_len, text_seq_len).expand(batch_size, -1, -1, -1)
formatted_text_attention_mask = torch.where(causal_mask.bool(), 0.0, float('-inf'))
# Handle speech attention mask (use speech_attention_mask if available, otherwise audio_attention_mask)
if speech_attention_mask is not None:
# For cached data, speech_attention_mask is already in the right format
speech_seq_len = speech_embeddings.shape[1]
speech_mask = speech_attention_mask.bool()
formatted_speech_attention_mask = speech_mask.view(batch_size, 1, 1, speech_seq_len)
formatted_speech_attention_mask = torch.where(formatted_speech_attention_mask, 0.0, float('-inf'))
elif audio_attention_mask is not None:
# For non-cached data, convert audio_attention_mask to speech_attention_mask
speech_seq_len = speech_embeddings.shape[1]
speech_mask = torch.zeros(batch_size, speech_seq_len, dtype=torch.bool, device=device)
for b in range(batch_size):
audio_len = audio_attention_mask[b].sum().item()
speech_len = int(audio_len * 12.5 / 24000)
speech_len = min(speech_len, speech_seq_len)
speech_mask[b, :speech_len] = True
formatted_speech_attention_mask = speech_mask.view(batch_size, 1, 1, speech_seq_len)
formatted_speech_attention_mask = torch.where(formatted_speech_attention_mask, 0.0, float('-inf'))
else:
formatted_speech_attention_mask = None
# Cross attention
cross_attention_outputs = self.cross_attention_transformer(
hidden_states=text_embeddings_proj,
encoder_hidden_states=speech_embeddings,
attention_mask=formatted_text_attention_mask,
encoder_attention_mask=formatted_speech_attention_mask,
alignment_chunk_sizes=None, # V1 learns alignment
)
cross_attention_outputs = cross_attention_outputs.last_hidden_state
# Get special embeddings
text_speech_latent_emb = self.text_speech_latent_embed(torch.zeros(1, dtype=torch.long, device=device))
time_speech_start_emb = self.time_speech_start_embed(torch.zeros(1, dtype=torch.long, device=device))
time_speech_end_emb = self.time_speech_end_embed(torch.zeros(1, dtype=torch.long, device=device))
generated_z_tokens = []
# Generate for each batch item
for b in range(batch_size):
# Get valid text length for this sample
if text_attention_mask is not None:
valid_text_len = text_attention_mask[b].sum().item()
else:
valid_text_len = text_embeddings_proj.shape[1]
# Start sequence with text_speech_latent for context
sequence = [text_speech_latent_emb] # (1, 512)
batch_z_tokens = [] # Store z_tokens for this batch item
# Generate for each text position
for i in range(valid_text_len):
# Add t_i and s_i
t_i = text_embeddings_proj[b, i:i+1] # (1, 512)
s_i = cross_attention_outputs[b, i:i+1] # (1, 512)
sequence.extend([t_i, s_i])
# Add time_speech_start
sequence.append(time_speech_start_emb)
# Generate z tokens autoregressively for this text position
z_count = 0
while z_count < max_z_tokens:
# Prepare current sequence for AR transformer
current_sequence = torch.cat(sequence, dim=0).unsqueeze(0) # (1, seq_len, 512)
# Create attention mask for current sequence
seq_len = current_sequence.shape[1]
ar_attention_mask = torch.ones(1, seq_len, dtype=torch.bool, device=device)
# Get prediction from AR transformer
ar_outputs = self.ar_transformer(
hidden_states=current_sequence,
attention_mask=ar_attention_mask,
)
# Get the last prediction
last_prediction = ar_outputs.last_hidden_state[0, -1:, :] # (1, 512)
# Check stopping condition using BCE classifier (v1.1)
end_token_logit = self.end_token_classifier(last_prediction).squeeze(-1) # (1,)
end_token_prob = torch.sigmoid(end_token_logit).item() # Convert to probability
# Stop if probability is high enough (>= threshold means stop)
if end_token_prob >= end_token_threshold:
# Stop generating z tokens
break
else:
# Add this prediction as next z token to both sequence (for context) and z_tokens (for output)
sequence.append(last_prediction)
batch_z_tokens.append(last_prediction.squeeze(0)) # Remove batch dimension for output
z_count += 1
# Add time_speech_end to sequence for context
sequence.append(time_speech_end_emb)
# Store z_tokens for this batch item
generated_z_tokens.append(batch_z_tokens)
return generated_z_tokens
def forward(
self,
text_token_ids: torch.LongTensor,
input_values: Optional[torch.Tensor] = None,
speech_embeddings: Optional[torch.Tensor] = None,
alignment_chunk_sizes: Optional[torch.Tensor] = None,
audio_attention_mask: Optional[torch.Tensor] = None,
speech_attention_mask: Optional[torch.Tensor] = None,
text_attention_mask: Optional[torch.Tensor] = None,
**kwargs,
) -> Dict[str, torch.Tensor]:
"""
Forward pass for training.
Args:
text_token_ids: Text token IDs (B, L)
input_values: Audio input (B, 1, 24000 * T) - for normal mode
speech_embeddings: Pre-computed speech embeddings (B, T, 512) - for cached mode
alignment_chunk_sizes: Number of speech frames (z tokens) aligned to each text token (B, L); required for training
audio_attention_mask: Audio mask (B, audio_seq_len)
speech_attention_mask: Speech mask (B, speech_seq_len)
text_attention_mask: Text mask (B, text_seq_len)
Returns:
Dictionary with 'loss', 'reconstruction_loss', and 'bce_end_token_loss'
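Example (training-step sketch; tensor names are placeholders):
>>> out = model(
...     text_token_ids=token_ids,            # (B, L)
...     input_values=waveforms,              # (B, 1, 24000 * T)
...     alignment_chunk_sizes=chunk_sizes,   # (B, L) speech frames per text token
...     text_attention_mask=text_mask,       # (B, L)
... )
>>> out['loss'].backward()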
"""
# Get speech embeddings
if speech_embeddings is not None:
pass  # Cached mode: use the pre-computed speech embeddings (B, T, 512) as-is
elif input_values is not None:
# Normal mode: compute speech embeddings from input_values
speech_embeddings_raw = self.encode_audio_to_representation(
input_values,
audio_attention_mask
)
# speech_embeddings_raw.shape = (B, 512, 12.5*T)
# Transpose: [B, 512, 12.5*T] -> [B, 12.5*T, 512]
speech_embeddings = speech_embeddings_raw.transpose(1, 2)
else:
raise ValueError("Either input_values or speech_embeddings must be provided")
# Embed token ids and project to 512-dim
text_embeddings_4096 = self.text_token_embedding(text_token_ids) # (B, L, 4096)
text_embeddings = self.text_proj(text_embeddings_4096) # (B, L, 512)
# Create proper attention masks for cross-attention
formatted_text_attention_mask = None
formatted_speech_attention_mask = None
# Handle text attention mask (causal mask for decoder)
batch_size, text_seq_len = text_embeddings.shape[:2]
if text_attention_mask is not None:
# Create causal mask and apply padding mask
causal_mask = torch.tril(torch.ones(text_seq_len, text_seq_len, device=text_embeddings.device, dtype=text_embeddings.dtype))
causal_mask = causal_mask.view(1, 1, text_seq_len, text_seq_len).expand(batch_size, -1, -1, -1)
# Apply padding mask to causal mask
padding_mask = text_attention_mask.view(batch_size, 1, 1, text_seq_len)
combined_mask = causal_mask * padding_mask
# Convert to attention scores (-inf for masked positions)
formatted_text_attention_mask = torch.where(combined_mask.bool(), 0.0, float('-inf'))
else:
# Create causal mask for all positions (no padding mask)
causal_mask = torch.tril(torch.ones(text_seq_len, text_seq_len, device=text_embeddings.device, dtype=text_embeddings.dtype))
causal_mask = causal_mask.view(1, 1, text_seq_len, text_seq_len).expand(batch_size, -1, -1, -1)
formatted_text_attention_mask = torch.where(causal_mask.bool(), 0.0, float('-inf'))
# Handle speech attention mask (encoder mask)
# Use speech_attention_mask if available (cached mode), otherwise audio_attention_mask (normal mode)
if speech_attention_mask is not None:
# Cached mode: speech_attention_mask is already in the right format
speech_seq_len = speech_embeddings.shape[1]
speech_mask = speech_attention_mask.bool()
# Convert to attention format: [batch_size, 1, 1, speech_seq_len]
formatted_speech_attention_mask = speech_mask.view(batch_size, 1, 1, speech_seq_len)
formatted_speech_attention_mask = torch.where(formatted_speech_attention_mask, 0.0, float('-inf'))
elif audio_attention_mask is not None:
# Normal mode: convert audio mask to speech embedding mask
speech_seq_len = speech_embeddings.shape[1]
# Create speech attention mask based on actual lengths
speech_mask = torch.zeros(batch_size, speech_seq_len, dtype=torch.bool, device=speech_embeddings.device)
for b in range(batch_size):
audio_len = audio_attention_mask[b].sum().item()
speech_len = int(audio_len * 12.5 / 24000)
speech_len = min(speech_len, speech_seq_len)
speech_mask[b, :speech_len] = True
# Convert to attention format: [batch_size, 1, 1, speech_seq_len]
formatted_speech_attention_mask = speech_mask.view(batch_size, 1, 1, speech_seq_len)
formatted_speech_attention_mask = torch.where(formatted_speech_attention_mask, 0.0, float('-inf'))
else:
# No masking
formatted_speech_attention_mask = None
# Cross attention: text attends to speech (no alignment constraints in V1)
# hidden_states (decoder) = text, encoder_hidden_states = speech
cross_attention_outputs = self.cross_attention_transformer(
hidden_states=text_embeddings,
encoder_hidden_states=speech_embeddings,
attention_mask=formatted_text_attention_mask, # Causal mask for text (decoder)
encoder_attention_mask=formatted_speech_attention_mask, # Mask for speech (encoder)
alignment_chunk_sizes=None, # v1 doesn't use alignment_chunk_sizes -- the model should learn the alignment itself
)
cross_attention_outputs = cross_attention_outputs.last_hidden_state
# Auto-regressive decoder part
# Following v0.5 where the target is the dequantized Mimi decoder-input
# Compute target representation = Mimi decoder-input (quantized then dequantized)
# at the 12.5 frames/s frame rate, so T = 12.5 * duration_in_seconds
with torch.no_grad():
embeddings_bct = speech_embeddings.transpose(1, 2) # (B, 512, T)
codes_kbt = self.quantizer.encode(embeddings_bct) # [K, B, T]
codes_bkt = codes_kbt.transpose(0, 1) # [B, K, T]
decoder_input_emb = self.quantizer.decode(codes_bkt) # (B, 512, T)
target_representation = decoder_input_emb.transpose(1, 2) # (B, T, 512)
# Build the interleaved sequence for the autoregressive decoder
# as well as the mask for loss computation
# Get special embeddings (all are single embeddings)
device = text_embeddings.device
text_speech_latent_emb = self.text_speech_latent_embed(torch.zeros(1, dtype=torch.long, device=device)) # (1, 512)
time_speech_start_emb = self.time_speech_start_embed(torch.zeros(1, dtype=torch.long, device=device)) # (1, 512)
time_speech_end_emb = self.time_speech_end_embed(torch.zeros(1, dtype=torch.long, device=device)) # (1, 512)
batch_size = text_embeddings.shape[0]
interleaved_sequences = []
loss_masks = []
bce_labels_batch = [] # BCE labels: 0 for z tokens, 1 for time_speech_end_emb
bce_masks = [] # BCE mask: True for z tokens and time_speech_end_emb
sequence_lengths = [] # Track actual sequence lengths before padding
all_z_tokens = [] # Collect all valid z_tokens for separation loss
max_total_length = 0
for b in range(batch_size):
# Start with text_speech_latent embedding
sequence_parts = [text_speech_latent_emb] # List to collect sequence parts
loss_mask_parts = [False] # Don't compute loss on special tokens
bce_label_parts = [0] # BCE labels (dummy for text_speech_latent_emb)
bce_mask_parts = [False] # BCE mask (False for text_speech_latent_emb)
# Get valid text length for this batch item
if text_attention_mask is not None:
valid_text_len = text_attention_mask[b].sum().item()
else:
valid_text_len = text_embeddings.shape[1]
# Track current position in target_representation
speech_position = 0
# For each text token
for i in range(valid_text_len):
# Add t_i (text embedding)
t_i = text_embeddings[b, i:i+1] # (1, 512)
sequence_parts.append(t_i)
loss_mask_parts.append(False)
bce_label_parts.append(0) # Dummy label for t_i
bce_mask_parts.append(False) # No BCE loss for t_i
# Add s_i (cross attention output)
s_i = cross_attention_outputs[b, i:i+1] # (1, 512)
sequence_parts.append(s_i)
loss_mask_parts.append(False)
bce_label_parts.append(0) # Dummy label for s_i
bce_mask_parts.append(False) # No BCE loss for s_i
# Add time_speech_start
sequence_parts.append(time_speech_start_emb)
loss_mask_parts.append(False)
bce_label_parts.append(0) # Dummy label for time_speech_start
bce_mask_parts.append(False) # No BCE loss for time_speech_start
# Add z tokens for this chunk
chunk_size = alignment_chunk_sizes[b, i].item()
if chunk_size > 0: # Only add if chunk size is positive
end_position = speech_position + chunk_size
# Make sure we don't exceed target_representation length
end_position = min(end_position, target_representation.shape[1])
actual_chunk_size = end_position - speech_position
if actual_chunk_size > 0:
z_tokens = target_representation[b, speech_position:end_position] # (actual_chunk_size, 512)
sequence_parts.append(z_tokens)
loss_mask_parts.extend([True] * actual_chunk_size) # Compute loss on z tokens
bce_label_parts.extend([0] * actual_chunk_size) # Label 0 for z tokens
bce_mask_parts.extend([True] * actual_chunk_size) # Compute BCE loss on z tokens
# Collect z_tokens for separation loss computation
all_z_tokens.append(z_tokens)
speech_position = end_position
# Add time_speech_end
sequence_parts.append(time_speech_end_emb)
loss_mask_parts.append(False)
bce_label_parts.append(1)
bce_mask_parts.append(True)
# Concatenate all parts for this batch item
full_sequence = torch.cat(sequence_parts, dim=0) # (total_length, 512)
loss_mask = torch.tensor(loss_mask_parts, dtype=torch.bool, device=device)
bce_labels = torch.tensor(bce_label_parts, dtype=torch.float, device=device)
bce_mask = torch.tensor(bce_mask_parts, dtype=torch.bool, device=device)
interleaved_sequences.append(full_sequence)
loss_masks.append(loss_mask)
bce_labels_batch.append(bce_labels)
bce_masks.append(bce_mask)
sequence_lengths.append(full_sequence.shape[0]) # Track actual length before padding
max_total_length = max(max_total_length, full_sequence.shape[0])
# Pad sequences
padded_sequences = []
padded_loss_masks = []
padded_bce_labels = []
padded_bce_masks = []
for sequence, loss_mask, bce_labels, bce_mask in zip(interleaved_sequences, loss_masks, bce_labels_batch, bce_masks):
current_length = sequence.shape[0]
if current_length < max_total_length:
padding = torch.zeros(max_total_length - current_length, 512, device=device, dtype=sequence.dtype)
padded_sequence = torch.cat([sequence, padding], dim=0)
mask_padding = torch.zeros(max_total_length - current_length, dtype=torch.bool, device=device)
padded_mask = torch.cat([loss_mask, mask_padding], dim=0)
bce_label_padding = torch.zeros(max_total_length - current_length, dtype=torch.float, device=device)
padded_bce_label = torch.cat([bce_labels, bce_label_padding], dim=0)
bce_mask_padding = torch.zeros(max_total_length - current_length, dtype=torch.bool, device=device)
padded_bce_mask = torch.cat([bce_mask, bce_mask_padding], dim=0)
else:
padded_sequence = sequence
padded_mask = loss_mask
padded_bce_label = bce_labels
padded_bce_mask = bce_mask
padded_sequences.append(padded_sequence)
padded_loss_masks.append(padded_mask)
padded_bce_labels.append(padded_bce_label)
padded_bce_masks.append(padded_bce_mask)
# Stack into batch tensors
interleaved_batch = torch.stack(padded_sequences, dim=0) # (batch_size, max_total_length, 512)
loss_mask_batch = torch.stack(padded_loss_masks, dim=0) # (batch_size, max_total_length)
bce_labels_batch_tensor = torch.stack(padded_bce_labels, dim=0) # (batch_size, max_total_length)
bce_mask_batch = torch.stack(padded_bce_masks, dim=0) # (batch_size, max_total_length)
# Autoregressive prediction
if max_total_length > 1:
ar_input = interleaved_batch[:, :-1, :] # (batch_size, max_total_length-1, 512)
ar_targets = interleaved_batch[:, 1:, :] # (batch_size, max_total_length-1, 512)
ar_loss_mask = loss_mask_batch[:, 1:] # (batch_size, max_total_length-1) - shift mask left
ar_bce_labels = bce_labels_batch_tensor[:, 1:] # (batch_size, max_total_length-1) - shift labels left
ar_bce_mask = bce_mask_batch[:, 1:] # (batch_size, max_total_length-1) - shift mask left
# Create attention mask for autoregressive transformer
# We need to mask padded positions while maintaining causal property
ar_seq_len = ar_input.shape[1]
ar_attention_mask = torch.zeros(batch_size, ar_seq_len, dtype=torch.bool, device=device)
for b in range(batch_size):
valid_len = min(ar_seq_len, sequence_lengths[b] - 1)
if valid_len > 0:
ar_attention_mask[b, :valid_len] = True
ar_outputs = self.ar_transformer(
hidden_states=ar_input,
attention_mask=ar_attention_mask, # This will be combined with causal mask inside transformer
)
ar_predictions = ar_outputs.last_hidden_state # (batch_size, max_total_length-1, 512)
# Compute BCE predictions for end token classification
bce_logits = self.end_token_classifier(ar_predictions).squeeze(-1) # (batch_size, max_total_length-1)
# Compute L2 loss only where ar_loss_mask is True (z tokens)
if ar_loss_mask.any():
# Extract valid positions for loss computation
valid_predictions = ar_predictions[ar_loss_mask] # (num_valid_positions, 512)
valid_targets = ar_targets[ar_loss_mask] # (num_valid_positions, 512)
# Compute L2 loss (MSE)
reconstruction_loss = nn.functional.mse_loss(
valid_predictions,
valid_targets,
reduction='mean'
)
else:
# Fallback if no valid positions (shouldn't happen in practice)
reconstruction_loss = torch.tensor(0.0, device=device, requires_grad=True)
# Compute BCE loss for end token classification (v1.1)
if ar_bce_mask.any():
# Extract valid positions for BCE loss computation
valid_bce_logits = bce_logits[ar_bce_mask] # (num_valid_bce_positions,)
valid_bce_labels = ar_bce_labels[ar_bce_mask] # (num_valid_bce_positions,)
# Compute BCE loss
bce_end_token_loss = nn.functional.binary_cross_entropy_with_logits(
valid_bce_logits,
valid_bce_labels,
reduction='mean'
)
else:
# Fallback if no valid BCE positions
bce_end_token_loss = torch.tensor(0.0, device=device, requires_grad=True)
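# Hinge the BCE term: with the defaults alpha=1.0 and bce_threshold=0.1, a BCE loss of 0.25
# contributes 0.15 to the total loss, while a BCE loss of 0.08 contributes nothing.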
if self.bce_threshold > 0.0:
clamped_bce_loss = torch.clamp(bce_end_token_loss - self.bce_threshold, min=0.0)
total_loss = reconstruction_loss + self.alpha * clamped_bce_loss
else:
total_loss = reconstruction_loss + self.alpha * bce_end_token_loss
else:
reconstruction_loss = torch.tensor(0.0, device=device, requires_grad=True)
bce_end_token_loss = torch.tensor(0.0, device=device, requires_grad=True)
total_loss = reconstruction_loss + torch.tensor(0.0, device=device, requires_grad=True)
return {
'loss': total_loss,
'reconstruction_loss': reconstruction_loss,
'bce_end_token_loss': bce_end_token_loss,
}
__all__ = ["TextSyncMimi"]