# =============================================================================
# routing/aggregator.py
# =============================================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, List, Tuple

from core.config import MambaConfig


class AttentionAggregator(nn.Module):
    """Attention-based aggregator for combining specialist outputs."""

    def __init__(self, config: MambaConfig):
        super().__init__()
        self.config = config
        self.d_model = config.d_model
        self.num_specialists = config.num_specialists

        # Self-attention over specialist encodings so specialists can exchange
        # information before their outputs are pooled
        self.specialist_attention = nn.MultiheadAttention(
            embed_dim=self.d_model,
            num_heads=8,
            dropout=0.1,
            batch_first=True,
        )

        # Project scalar specialist confidence scores into model space
        self.confidence_proj = nn.Linear(1, self.d_model)

        # Feed-forward output block applied to the pooled representation
        self.output_layers = nn.Sequential(
            nn.Linear(self.d_model, self.d_model * 2),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(self.d_model * 2, self.d_model),
            nn.LayerNorm(self.d_model),
        )

        # Final language modeling head
        self.lm_head = nn.Linear(self.d_model, config.vocab_size, bias=False)

    def forward(self, specialist_outputs: Dict[int, List[Dict]]) -> torch.Tensor:
        """
        Aggregate specialist outputs into final logits.

        Args:
            specialist_outputs: Dict mapping chunk_id to a list of specialist
                result dicts, each holding an 'encoding' tensor of shape
                [d_model] and a scalar 'confidence'.

        Returns:
            aggregated_logits: [1, num_chunks, vocab_size]
        """
        batch_outputs = []

        for chunk_id in sorted(specialist_outputs.keys()):
            chunk_results = specialist_outputs[chunk_id]
            if not chunk_results:
                continue

            # Collect per-specialist encodings and confidence scores
            encodings = []
            confidences = []
            for result in chunk_results:
                if result is not None:
                    encodings.append(result['encoding'])
                    confidences.append(result['confidence'])
            if not encodings:
                continue

            # Stack tensors
            specialist_encodings = torch.stack(encodings)  # [num_specialists, d_model]
            confidence_scores = torch.tensor(
                confidences,
                dtype=specialist_encodings.dtype,
                device=specialist_encodings.device,
            )

            # Project confidence scores into the embedding space
            confidence_embeddings = self.confidence_proj(
                confidence_scores.unsqueeze(-1)
            )  # [num_specialists, d_model]

            # Add confidence information to the encodings
            enhanced_encodings = specialist_encodings + confidence_embeddings

            # Self-attention lets the specialist outputs communicate before pooling
            aggregated, _ = self.specialist_attention(
                enhanced_encodings.unsqueeze(0),  # [1, num_specialists, d_model]
                enhanced_encodings.unsqueeze(0),
                enhanced_encodings.unsqueeze(0),
            )

            # Pool the attended representations
            chunk_representation = aggregated.mean(dim=1)  # [1, d_model]

            # Apply output layers
            chunk_output = self.output_layers(chunk_representation)
            batch_outputs.append(chunk_output)

        if not batch_outputs:
            # Return a dummy output on the model's device if no valid results
            return torch.zeros(
                1, 1, self.config.vocab_size, device=self.lm_head.weight.device
            )

        # Concatenate chunk outputs
        final_representation = torch.cat(batch_outputs, dim=0)  # [num_chunks, d_model]

        # Generate logits
        logits = self.lm_head(final_representation)  # [num_chunks, vocab_size]
        return logits.unsqueeze(0)  # [1, num_chunks, vocab_size]

    def generate_response(self, specialist_outputs: Dict[int, List[Dict]],
                          max_tokens: int = 100) -> str:
        """Generate a text response from specialist outputs."""
        # Get aggregated logits: [1, num_chunks, vocab_size]
        logits = self.forward(specialist_outputs)

        # Simple greedy decoding: pick the argmax token for each chunk position.
        # Because generated tokens are not fed back through the model, this is a
        # placeholder rather than true autoregressive generation.
        generated_ids = []
        num_chunks = logits.size(1)
        for step in range(min(max_tokens, num_chunks)):
            next_token = torch.argmax(logits[0, step, :], dim=-1)
            generated_ids.append(next_token.item())
            # Stop on the EOS token (assuming token id 0 is EOS)
            if next_token.item() == 0:
                break

        # Convert to text (placeholder -- integrate a real tokenizer to decode
        # generated_ids into actual text)
        response = f"Generated response with {len(generated_ids)} tokens"
        return response