# =============================================================================
# routing/aggregator.py
# =============================================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, List, Tuple
from core.config import MambaConfig
class AttentionAggregator(nn.Module):
    """Attention-based aggregator for combining specialist outputs"""

    def __init__(self, config: MambaConfig):
        super().__init__()
        self.config = config
        self.d_model = config.d_model
        self.num_specialists = config.num_specialists

        # Attention mechanism for combining specialist outputs
        self.specialist_attention = nn.MultiheadAttention(
            embed_dim=self.d_model,
            num_heads=8,
            dropout=0.1,
            batch_first=True
        )

        # Project specialist confidence scores
        self.confidence_proj = nn.Linear(1, self.d_model)

        # Output layers
        self.output_layers = nn.Sequential(
            nn.Linear(self.d_model, self.d_model * 2),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(self.d_model * 2, self.d_model),
            nn.LayerNorm(self.d_model)
        )

        # Final language modeling head
        self.lm_head = nn.Linear(self.d_model, config.vocab_size, bias=False)
    def forward(self, specialist_outputs: Dict[int, List[Dict]]) -> torch.Tensor:
        """
        Aggregate specialist outputs into a final representation.

        Args:
            specialist_outputs: Dict mapping chunk_id to a list of specialist
                results; each result is a dict with an 'encoding' tensor of
                shape [d_model] and a scalar 'confidence' score.

        Returns:
            aggregated_logits: [1, num_chunks, vocab_size] (one position per chunk)
        """
        batch_outputs = []

        for chunk_id in sorted(specialist_outputs.keys()):
            chunk_results = specialist_outputs[chunk_id]
            if not chunk_results:
                continue

            # Collect specialist encodings and confidence scores
            encodings = []
            confidences = []
            for result in chunk_results:
                if result is not None:
                    encodings.append(result['encoding'])
                    confidences.append(result['confidence'])

            if not encodings:
                continue

            # Stack tensors
            specialist_encodings = torch.stack(encodings)  # [num_specialists, d_model]
            confidence_scores = torch.tensor(
                confidences,
                device=specialist_encodings.device,
                dtype=specialist_encodings.dtype
            )

            # Project confidence scores
            confidence_embeddings = self.confidence_proj(
                confidence_scores.unsqueeze(-1)
            )  # [num_specialists, d_model]

            # Add confidence information to encodings
            enhanced_encodings = specialist_encodings + confidence_embeddings

            # Apply self-attention so specialists can attend to each other
            aggregated, _ = self.specialist_attention(
                enhanced_encodings.unsqueeze(0),  # [1, num_specialists, d_model]
                enhanced_encodings.unsqueeze(0),
                enhanced_encodings.unsqueeze(0)
            )

            # Pool the attended representations
            chunk_representation = aggregated.mean(dim=1)  # [1, d_model]

            # Apply output layers
            chunk_output = self.output_layers(chunk_representation)
            batch_outputs.append(chunk_output)

        if not batch_outputs:
            # Return a dummy output on the model's device if there are no valid results
            return torch.zeros(
                1, 1, self.config.vocab_size, device=self.lm_head.weight.device
            )

        # Concatenate chunk outputs
        final_representation = torch.cat(batch_outputs, dim=0)  # [num_chunks, d_model]

        # Generate logits
        logits = self.lm_head(final_representation)  # [num_chunks, vocab_size]
        return logits.unsqueeze(0)  # [1, num_chunks, vocab_size]
    def generate_response(self, specialist_outputs: Dict[int, List[Dict]],
                          max_tokens: int = 100) -> str:
        """Generate a text response from specialist outputs"""
        # Get aggregated logits
        logits = self.forward(specialist_outputs)

        # Simple greedy decoding placeholder: the logits are not recomputed
        # after each step, so this repeatedly takes the argmax of the last
        # chunk's distribution; a full implementation would feed each
        # generated token back through the model.
        generated_ids = []
        current_logits = logits[0, -1, :]  # Use the last chunk's logits

        for _ in range(max_tokens):
            # Get next token
            next_token = torch.argmax(current_logits, dim=-1)
            generated_ids.append(next_token.item())

            # Break on EOS token (assuming token 0 is EOS)
            if next_token.item() == 0:
                break

        # Convert to text (placeholder: integrate with an actual tokenizer for real text)
        response = f"Generated response with {len(generated_ids)} tokens"
        return response