# =============================================================================
# routing/aggregator.py
# =============================================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, List
from core.config import MambaConfig

class AttentionAggregator(nn.Module):
    """Attention-based aggregator for combining specialist outputs"""
    
    def __init__(self, config: MambaConfig):
        super().__init__()
        self.config = config
        self.d_model = config.d_model
        self.num_specialists = config.num_specialists
        
        # Attention mechanism for combining specialist outputs
        self.specialist_attention = nn.MultiheadAttention(
            embed_dim=self.d_model,
            num_heads=8,
            dropout=0.1,
            batch_first=True
        )
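        # NOTE: nn.MultiheadAttention requires embed_dim to be divisible by
        # num_heads, so config.d_model must be a multiple of 8 here.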
        
        # Project specialist confidence scores
        self.confidence_proj = nn.Linear(1, self.d_model)
        
        # Output layers
        self.output_layers = nn.Sequential(
            nn.Linear(self.d_model, self.d_model * 2),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(self.d_model * 2, self.d_model),
            nn.LayerNorm(self.d_model)
        )
        
        # Final language modeling head
        self.lm_head = nn.Linear(self.d_model, config.vocab_size, bias=False)
        
    def forward(self, specialist_outputs: Dict[int, List[Dict]]) -> torch.Tensor:
        """

        Aggregate specialist outputs into final representation

        

        Args:

            specialist_outputs: Dict mapping chunk_id to list of specialist results

            

        Returns:

            aggregated_logits: [batch, seq_len, vocab_size]

        """
        batch_outputs = []
        
        for chunk_id in sorted(specialist_outputs.keys()):
            chunk_results = specialist_outputs[chunk_id]
            
            if not chunk_results:
                continue
            
            # Stack specialist encodings
            encodings = []
            confidences = []
            
            for result in chunk_results:
                if result is not None:
                    encodings.append(result['encoding'])
                    confidences.append(result['confidence'])
            
            if not encodings:
                continue
            
            # Stack tensors (num_results may be < num_specialists if some are None)
            specialist_encodings = torch.stack(encodings)  # [num_results, d_model]
            confidence_scores = torch.tensor(confidences, device=specialist_encodings.device)
            
            # Project confidence scores into the model dimension
            confidence_embeddings = self.confidence_proj(
                confidence_scores.unsqueeze(-1)
            )  # [num_results, d_model]
            
            # Add confidence information to encodings
            enhanced_encodings = specialist_encodings + confidence_embeddings
            
            # Apply attention to combine specialist outputs
            # Use self-attention to let specialists communicate
            aggregated, _ = self.specialist_attention(
                enhanced_encodings.unsqueeze(0),  # [1, num_results, d_model]
                enhanced_encodings.unsqueeze(0),
                enhanced_encodings.unsqueeze(0)
            )
            
            # Pool the attended representations
            chunk_representation = aggregated.mean(dim=1)  # [1, d_model]
            
            # Apply output layers
            chunk_output = self.output_layers(chunk_representation)
            batch_outputs.append(chunk_output)
        
        if not batch_outputs:
            # Return a dummy output on the module's device if no valid results
            return torch.zeros(1, 1, self.config.vocab_size,
                               device=self.lm_head.weight.device)
        
        # Concatenate chunk outputs
        final_representation = torch.cat(batch_outputs, dim=0)  # [num_chunks, d_model]
        
        # Generate logits
        logits = self.lm_head(final_representation)  # [num_chunks, vocab_size]
        
        return logits.unsqueeze(0)  # [1, num_chunks, vocab_size]
    
    def generate_response(self, specialist_outputs: Dict[int, List[Dict]],
                          max_tokens: int = 100) -> str:
        """Generate text response from specialist outputs"""
        # Get aggregated logits
        logits = self.forward(specialist_outputs)
        
        # Simple greedy decoding: take the argmax token at each chunk position.
        # (The original loop re-used the same logits every step; without an
        # autoregressive feedback path there is nothing to update them with,
        # so we decode one token per chunk instead.)
        generated_ids = []
        
        for pos in range(min(max_tokens, logits.size(1))):
            next_token = torch.argmax(logits[0, pos, :], dim=-1)
            generated_ids.append(next_token.item())
            
            # Break on EOS token (assuming token 0 is EOS)
            if next_token.item() == 0:
                break
        
        # Placeholder: integrate a real tokenizer to decode generated_ids to text
        response = f"Generated response with {len(generated_ids)} tokens"
        
        return response
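
# -----------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not canonical): exercises the aggregator
# with a dummy config and fake specialist outputs. The real MambaConfig lives
# in core.config; a SimpleNamespace stands in for it here, exposing only the
# attributes this module reads (d_model, num_specialists, vocab_size), so the
# field values below are assumptions for demonstration.
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    config = SimpleNamespace(d_model=64, num_specialists=4, vocab_size=1000)
    aggregator = AttentionAggregator(config)

    # Two chunks, each served by one or more specialists; each result carries
    # a d_model-sized encoding and a scalar confidence, matching the dict
    # structure forward() expects.
    specialist_outputs = {
        0: [{'encoding': torch.randn(64), 'confidence': 0.9},
            {'encoding': torch.randn(64), 'confidence': 0.4}],
        1: [{'encoding': torch.randn(64), 'confidence': 0.7}],
    }

    with torch.no_grad():
        logits = aggregator(specialist_outputs)
    print(logits.shape)  # torch.Size([1, 2, 1000]) == [1, num_chunks, vocab_size]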