# =============================================================================
# routing/aggregator.py
# =============================================================================
import torch
import torch.nn as nn
from typing import Dict, List
from core.config import MambaConfig


class AttentionAggregator(nn.Module):
    """Attention-based aggregator for combining specialist outputs"""

    def __init__(self, config: MambaConfig):
        super().__init__()
        self.config = config
        self.d_model = config.d_model
        self.num_specialists = config.num_specialists

        # Attention mechanism for combining specialist outputs
        self.specialist_attention = nn.MultiheadAttention(
            embed_dim=self.d_model,
            num_heads=8,
            dropout=0.1,
            batch_first=True,
        )

        # Project scalar specialist confidence scores into model space
        self.confidence_proj = nn.Linear(1, self.d_model)

        # Output layers
        self.output_layers = nn.Sequential(
            nn.Linear(self.d_model, self.d_model * 2),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(self.d_model * 2, self.d_model),
            nn.LayerNorm(self.d_model),
        )

        # Final language modeling head
        self.lm_head = nn.Linear(self.d_model, config.vocab_size, bias=False)
    def forward(self, specialist_outputs: Dict[int, List[Dict]]) -> torch.Tensor:
        """
        Aggregate specialist outputs into a final representation.

        Args:
            specialist_outputs: Dict mapping chunk_id to a list of specialist
                result dicts, each holding an 'encoding' tensor of shape
                [d_model] and a scalar 'confidence'.

        Returns:
            aggregated_logits: [1, num_chunks, vocab_size]
        """
        batch_outputs = []

        for chunk_id in sorted(specialist_outputs.keys()):
            chunk_results = specialist_outputs[chunk_id]
            if not chunk_results:
                continue

            # Collect specialist encodings and confidence scores
            encodings = []
            confidences = []
            for result in chunk_results:
                if result is not None:
                    encodings.append(result['encoding'])
                    confidences.append(float(result['confidence']))
            if not encodings:
                continue

            # Stack tensors
            specialist_encodings = torch.stack(encodings)  # [num_specialists, d_model]
            confidence_scores = torch.tensor(
                confidences, device=specialist_encodings.device
            )

            # Project confidence scores into the model dimension
            confidence_embeddings = self.confidence_proj(
                confidence_scores.unsqueeze(-1)
            )  # [num_specialists, d_model]

            # Add confidence information to the encodings
            enhanced_encodings = specialist_encodings + confidence_embeddings

            # Apply self-attention so the specialists can exchange information
            aggregated, _ = self.specialist_attention(
                enhanced_encodings.unsqueeze(0),  # [1, num_specialists, d_model]
                enhanced_encodings.unsqueeze(0),
                enhanced_encodings.unsqueeze(0),
            )

            # Pool the attended representations
            chunk_representation = aggregated.mean(dim=1)  # [1, d_model]

            # Apply output layers
            chunk_output = self.output_layers(chunk_representation)
            batch_outputs.append(chunk_output)

        if not batch_outputs:
            # Return a dummy output on the module's device if no valid results exist
            device = self.lm_head.weight.device
            return torch.zeros(1, 1, self.config.vocab_size, device=device)

        # Concatenate chunk outputs
        final_representation = torch.cat(batch_outputs, dim=0)  # [num_chunks, d_model]

        # Generate logits
        logits = self.lm_head(final_representation)  # [num_chunks, vocab_size]
        return logits.unsqueeze(0)  # [1, num_chunks, vocab_size]
    def generate_response(self, specialist_outputs: Dict[int, List[Dict]],
                          max_tokens: int = 100) -> str:
        """Generate a text response from specialist outputs."""
        # Get aggregated logits
        logits = self.forward(specialist_outputs)

        # Placeholder greedy decoding: the logits are not re-computed per step,
        # so this repeatedly takes the argmax of the last chunk's logits until
        # the assumed EOS token appears or max_tokens is reached. Replace with
        # proper autoregressive generation for real use.
        generated_ids = []
        current_logits = logits[0, -1, :]  # Use the last chunk's logits

        for _ in range(max_tokens):
            # Pick the most likely next token
            next_token = torch.argmax(current_logits, dim=-1)
            generated_ids.append(next_token.item())

            # Break on EOS token (assuming token 0 is EOS)
            if next_token.item() == 0:
                break

        # Convert to text (placeholder - integrate a proper tokenizer for real output)
        response = f"Generated response with {len(generated_ids)} tokens"
        return response