File size: 6,109 Bytes
2ee6fe0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
# =============================================================================
# routing/router.py
# =============================================================================
import torch
import torch.nn as nn
import numpy as np
from typing import List, Dict, Tuple, Optional
from collections import defaultdict
import re
from utils.domain_configs import DomainConfigs

class TopicRouter(nn.Module):
    """Route input text to domain specialists.

    Combines two signals:
      * keyword matching against each domain's keyword list (cheap, no
        embeddings required), and
      * an optional learned MLP router over pooled token embeddings.

    Args:
        config: Model config; only ``config.d_model`` is read here (input
            width of the neural router).
        domain_configs: One dict per specialist with at least ``"id"``
            (int) and ``"keywords"`` (list of str).
            NOTE(review): ``route_text`` assumes the ``"id"`` values are
            exactly ``0 .. len(domain_configs) - 1`` — keyword scores are
            keyed by id while neural scores are keyed by position. Verify
            against the configs producing these dicts.
    """

    # Fraction of the final score contributed by each signal.
    KEYWORD_WEIGHT = 0.7
    NEURAL_WEIGHT = 0.3

    def __init__(self, config, domain_configs: List[Dict]):
        super().__init__()
        self.config = config
        self.domain_configs = domain_configs
        self.num_specialists = len(domain_configs)

        # Inverted index (keyword -> domain ids) and forward map
        # (domain id -> keywords) built once at construction.
        self.keyword_to_domains = defaultdict(list)
        self.domain_keywords = {}
        for domain in domain_configs:
            domain_id = domain["id"]
            keywords = domain["keywords"]
            self.domain_keywords[domain_id] = keywords
            for keyword in keywords:
                self.keyword_to_domains[keyword.lower()].append(domain_id)

        # MLP over mean-pooled embeddings for routing decisions that
        # keywords alone cannot resolve.
        self.neural_router = nn.Sequential(
            nn.Linear(config.d_model, 512),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, self.num_specialists),
        )

        # Minimum combined score for a specialist to be activated.
        self.similarity_threshold = 0.1

    def keyword_based_routing(self, text: str) -> Dict[int, float]:
        """Score each domain by keyword occurrences in *text*.

        Scores are weighted by occurrence count and keyword length
        (longer keywords are treated as more specific), then normalized
        to sum to 1 across the matched domains.

        Returns:
            Dict mapping domain id -> normalized score; empty if no
            keyword matched.
        """
        text_lower = text.lower()
        domain_scores = defaultdict(float)

        for domain_id, keywords in self.domain_keywords.items():
            for keyword in keywords:
                # NOTE(review): substring match — "cat" also matches
                # "category". Switch to word-boundary matching if that
                # proves too noisy; kept as-is to preserve behavior.
                if keyword in text_lower:
                    count = text_lower.count(keyword)
                    weight = len(keyword) / 10.0  # longer keyword => higher weight
                    domain_scores[domain_id] += count * weight

        total_score = sum(domain_scores.values())
        if total_score > 0:
            domain_scores = {k: v / total_score for k, v in domain_scores.items()}

        return dict(domain_scores)

    def neural_routing(self, embeddings: torch.Tensor) -> torch.Tensor:
        """Return softmax routing weights from the learned router.

        Args:
            embeddings: Token embeddings of shape [batch, seq_len, d_model].

        Returns:
            Tensor of shape [batch, num_specialists] summing to 1 per row.
        """
        pooled = embeddings.mean(dim=1)  # mean-pool over sequence -> [batch, d_model]
        scores = self.neural_router(pooled)  # [batch, num_specialists]
        return torch.softmax(scores, dim=-1)

    def route_text(self, text: str, embeddings: Optional[torch.Tensor] = None,
                   max_specialists: int = 10) -> List[Tuple[int, float]]:
        """Route *text* to the most relevant specialists.

        Args:
            text: Input text to route.
            embeddings: Optional text embeddings [1, seq_len, d_model];
                when omitted, only keyword scores are used.
            max_specialists: Maximum number of specialists to activate.

        Returns:
            List of (specialist_id, confidence) tuples, highest score
            first. Always contains at least one entry when there is at
            least one specialist (best candidate is used as a fallback
            even below the threshold).
        """
        keyword_scores = self.keyword_based_routing(text)

        neural_scores: Dict[int, float] = {}
        if embeddings is not None:
            neural_weights = self.neural_routing(embeddings)
            neural_scores = {i: float(neural_weights[0, i])
                             for i in range(self.num_specialists)}

        # Fixed-weight blend of the two signals per specialist.
        final_scores = {
            i: self.KEYWORD_WEIGHT * keyword_scores.get(i, 0.0)
               + self.NEURAL_WEIGHT * neural_scores.get(i, 0.0)
            for i in range(self.num_specialists)
        }

        sorted_specialists = sorted(final_scores.items(),
                                    key=lambda item: item[1],
                                    reverse=True)

        # Keep the top scorers above the activation threshold.
        active_specialists = [
            (specialist_id, score)
            for specialist_id, score in sorted_specialists[:max_specialists]
            if score > self.similarity_threshold
        ]

        # Guarantee at least one active specialist.
        if not active_specialists and sorted_specialists:
            active_specialists = [sorted_specialists[0]]

        return active_specialists

    def chunk_and_route(self, text: str, chunk_size: int = 512) -> List[Dict]:
        """Split *text* into sentence-based chunks and route each one.

        Sentences are accumulated until adding the next one would exceed
        ``chunk_size`` characters, at which point the current chunk is
        routed and a new chunk starts with that sentence.

        Returns:
            List of dicts with keys 'text', 'specialists', 'chunk_id'.
        """
        sentences = re.split(r'[.!?]+', text)
        chunks: List[Dict] = []
        current_chunk = ""
        chunk_id = 0

        for sentence in sentences:
            sentence = sentence.strip()
            if not sentence:
                # Skip empty fragments (e.g. the trailing split after a
                # final period) so they don't inject stray ". " text.
                continue
            if len(current_chunk) + len(sentence) > chunk_size and current_chunk:
                # Current chunk is full: route it and start a new one.
                specialists = self.route_text(current_chunk)
                chunks.append({
                    'text': current_chunk.strip(),
                    'specialists': specialists,
                    'chunk_id': chunk_id,
                })
                # BUGFIX: carry the separator over — previously the
                # first sentence of a new chunk was stored without
                # ". ", so the next sentence was glued onto it.
                current_chunk = sentence + ". "
                chunk_id += 1
            else:
                current_chunk += sentence + ". "

        # Flush the final partial chunk, if any.
        if current_chunk.strip():
            specialists = self.route_text(current_chunk)
            chunks.append({
                'text': current_chunk.strip(),
                'specialists': specialists,
                'chunk_id': chunk_id,
            })

        return chunks