# =============================================================================
# routing/router.py
# =============================================================================
import torch
import torch.nn as nn
import numpy as np
from typing import List, Dict, Tuple, Optional
from collections import defaultdict
import re

from utils.domain_configs import DomainConfigs


class TopicRouter(nn.Module):
    def __init__(self, config, domain_configs: List[Dict]):
        super().__init__()
        self.config = config
        self.domain_configs = domain_configs
        self.num_specialists = len(domain_configs)

        # Build keyword mappings
        self.keyword_to_domains = defaultdict(list)
        self.domain_keywords = {}

        for domain in domain_configs:
            domain_id = domain["id"]
            keywords = domain["keywords"]
            self.domain_keywords[domain_id] = keywords
            for keyword in keywords:
                self.keyword_to_domains[keyword.lower()].append(domain_id)

        # Neural router for complex routing decisions
        self.neural_router = nn.Sequential(
            nn.Linear(config.d_model, 512),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, self.num_specialists)
        )

        # Minimum routing score a specialist must reach to be activated
        self.similarity_threshold = 0.1

    def keyword_based_routing(self, text: str) -> Dict[int, float]:
        """Route based on keyword matching."""
        text_lower = text.lower()
        domain_scores = defaultdict(float)

        # Count keyword matches for each domain
        for domain_id, keywords in self.domain_keywords.items():
            for keyword in keywords:
                if keyword in text_lower:
                    # Weight by keyword frequency and length
                    count = text_lower.count(keyword)
                    weight = len(keyword) / 10.0  # longer keywords get higher weight
                    domain_scores[domain_id] += count * weight

        # Normalize scores to sum to 1
        total_score = sum(domain_scores.values())
        if total_score > 0:
            domain_scores = {k: v / total_score for k, v in domain_scores.items()}

        return dict(domain_scores)

    def neural_routing(self, embeddings: torch.Tensor) -> torch.Tensor:
        """Neural-network-based routing."""
        # Mean-pool token embeddings into a single vector per sequence
        pooled = embeddings.mean(dim=1)          # [batch, d_model]
        scores = self.neural_router(pooled)      # [batch, num_specialists]
        return torch.softmax(scores, dim=-1)

    def route_text(self, text: str, embeddings: Optional[torch.Tensor] = None,
                   max_specialists: int = 10) -> List[Tuple[int, float]]:
        """
        Route text to appropriate specialists.

        Args:
            text: Input text to route
            embeddings: Text embeddings [1, seq_len, d_model]
            max_specialists: Maximum number of specialists to activate

        Returns:
            List of (specialist_id, confidence) tuples
        """
        # Keyword-based routing
        keyword_scores = self.keyword_based_routing(text)

        # Neural routing (if embeddings provided)
        neural_scores = {}
        if embeddings is not None:
            neural_weights = self.neural_routing(embeddings)
            neural_scores = {i: float(neural_weights[0, i])
                             for i in range(self.num_specialists)}

        # Combine scores (assumes domain ids are the integers 0..num_specialists-1)
        final_scores = {}
        for i in range(self.num_specialists):
            keyword_score = keyword_scores.get(i, 0.0)
            neural_score = neural_scores.get(i, 0.0)
            # Weighted combination: keyword evidence dominates
            final_scores[i] = 0.7 * keyword_score + 0.3 * neural_score

        # Sort by score, highest first
        sorted_specialists = sorted(final_scores.items(), key=lambda x: x[1], reverse=True)

        # Filter by threshold and cap at max_specialists
        active_specialists = []
        for specialist_id, score in sorted_specialists:
            if score > self.similarity_threshold and len(active_specialists) < max_specialists:
                active_specialists.append((specialist_id, score))

        # Ensure at least one specialist is active
        if not active_specialists and sorted_specialists:
            active_specialists = [sorted_specialists[0]]

        return active_specialists

    def chunk_and_route(self, text: str, chunk_size: int = 512) -> List[Dict]:
        """
        Split text into chunks and route each chunk.

        Returns:
            List of dicts with 'text', 'specialists', 'chunk_id'
        """
        # Simple sentence-based chunking; drop empty fragments from the split
        sentences = [s.strip() for s in re.split(r'[.!?]+', text) if s.strip()]

        chunks = []
        current_chunk = ""
        chunk_id = 0

        for sentence in sentences:
            if len(current_chunk) + len(sentence) > chunk_size and current_chunk:
                # Route the completed chunk
                specialists = self.route_text(current_chunk)
                chunks.append({
                    'text': current_chunk.strip(),
                    'specialists': specialists,
                    'chunk_id': chunk_id
                })
                # Start the next chunk with the overflowing sentence
                current_chunk = sentence + ". "
                chunk_id += 1
            else:
                current_chunk += sentence + ". "

        # Handle last chunk
        if current_chunk.strip():
            specialists = self.route_text(current_chunk)
            chunks.append({
                'text': current_chunk.strip(),
                'specialists': specialists,
                'chunk_id': chunk_id
            })

        return chunks
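# -----------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, with hypothetical configs): assumes
# a config object exposing `d_model` and domain configs whose integer ids match
# the specialist indices 0..N-1, as route_text expects.
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    config = SimpleNamespace(d_model=256)  # hypothetical model width
    domains = [
        {"id": 0, "keywords": ["tensor", "gradient", "backprop"]},
        {"id": 1, "keywords": ["contract", "liability", "plaintiff"]},
    ]
    router = TopicRouter(config, domains)

    text = "The gradient flows back through every tensor during backprop."
    # Keyword routing alone: all normalized mass lands on domain 0,
    # scaled by the 0.7 combination weight
    print(router.route_text(text))  # -> [(0, 0.7)]

    # With embeddings, the (here untrained) neural router contributes
    # 30% of each specialist's final score
    emb = torch.randn(1, 16, config.d_model)
    print(router.route_text(text, embeddings=emb))

    # Chunked routing over a longer document
    for chunk in router.chunk_and_route(text * 20, chunk_size=128):
        print(chunk['chunk_id'], chunk['specialists'])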