# =============================================================================
# routing/router.py
# =============================================================================
import re
from collections import defaultdict
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn
class TopicRouter(nn.Module):
    def __init__(self, config, domain_configs: List[Dict]):
        super().__init__()
        self.config = config
        self.domain_configs = domain_configs
        self.num_specialists = len(domain_configs)

        # Build keyword mappings: keyword -> domain ids and domain id -> keywords
        self.keyword_to_domains = defaultdict(list)
        self.domain_keywords = {}
        for domain in domain_configs:
            domain_id = domain["id"]
            keywords = domain["keywords"]
            self.domain_keywords[domain_id] = keywords
            for keyword in keywords:
                self.keyword_to_domains[keyword.lower()].append(domain_id)

        # Neural router for complex routing decisions
        self.neural_router = nn.Sequential(
            nn.Linear(config.d_model, 512),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, self.num_specialists),
        )

        # Minimum combined score a specialist needs in order to be activated
        self.similarity_threshold = 0.1
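    # For example (hypothetical configs, not from this repo), domain_configs like
    #   [{"id": 0, "keywords": ["Gradient"]}, {"id": 1, "keywords": ["gradient"]}]
    # yield keyword_to_domains == {"gradient": [0, 1]} (keys are lowercased) and
    # domain_keywords == {0: ["Gradient"], 1: ["gradient"]} (original casing kept).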
    def keyword_based_routing(self, text: str) -> Dict[int, float]:
        """Route based on keyword matching."""
        text_lower = text.lower()
        domain_scores = defaultdict(float)

        # Count keyword matches for each domain. Note that `in` does substring
        # matching, so short keywords can fire inside unrelated words.
        for domain_id, keywords in self.domain_keywords.items():
            for keyword in keywords:
                keyword_lower = keyword.lower()
                if keyword_lower in text_lower:
                    # Weight by keyword frequency and length
                    # (longer keywords get higher weight)
                    count = text_lower.count(keyword_lower)
                    weight = len(keyword) / 10.0
                    domain_scores[domain_id] += count * weight

        # Normalize scores so they sum to 1
        total_score = sum(domain_scores.values())
        if total_score > 0:
            domain_scores = {k: v / total_score for k, v in domain_scores.items()}

        return dict(domain_scores)
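    # A hedged worked example (hypothetical domain setup, not from this repo):
    # with domain_keywords = {0: ["neural network", "gradient"], 1: ["contract"]}
    # and text "Gradient descent trains the neural network.", the raw scores are
    #   domain 0: 1 * (len("neural network") / 10) + 1 * (len("gradient") / 10)
    #           = 1.4 + 0.8 = 2.2,   domain 1: 0.0,
    # which normalize to {0: 1.0}.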
    def neural_routing(self, embeddings: torch.Tensor) -> torch.Tensor:
        """Neural network based routing."""
        # Mean-pool the token embeddings into one vector per sequence
        pooled = embeddings.mean(dim=1)      # [batch, d_model]
        scores = self.neural_router(pooled)  # [batch, num_specialists]
        return torch.softmax(scores, dim=-1)
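    # Shape sketch (toy tensors; assumes only that config.d_model is set):
    #   probs = router.neural_routing(torch.randn(2, 128, config.d_model))
    #   probs.shape        # torch.Size([2, num_specialists])
    #   probs.sum(dim=-1)  # approximately ones(2), since each row is softmaxed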
    def route_text(self, text: str, embeddings: Optional[torch.Tensor] = None,
                   max_specialists: int = 10) -> List[Tuple[int, float]]:
        """
        Route text to appropriate specialists.

        Args:
            text: Input text to route
            embeddings: Text embeddings [1, seq_len, d_model]
            max_specialists: Maximum number of specialists to activate

        Returns:
            List of (specialist_id, confidence) tuples
        """
        # Keyword-based routing
        keyword_scores = self.keyword_based_routing(text)

        # Neural routing (only if embeddings are provided)
        neural_scores = {}
        if embeddings is not None:
            neural_weights = self.neural_routing(embeddings)
            neural_scores = {i: float(neural_weights[0, i])
                             for i in range(self.num_specialists)}

        # Combine scores with a fixed 70/30 weighting in favor of keywords
        final_scores = {}
        for i in range(self.num_specialists):
            keyword_score = keyword_scores.get(i, 0.0)
            neural_score = neural_scores.get(i, 0.0)
            final_scores[i] = 0.7 * keyword_score + 0.3 * neural_score

        # Sort by score, descending
        sorted_specialists = sorted(final_scores.items(),
                                    key=lambda x: x[1],
                                    reverse=True)

        # Keep specialists above the threshold, up to max_specialists
        active_specialists = []
        for specialist_id, score in sorted_specialists:
            if score > self.similarity_threshold and len(active_specialists) < max_specialists:
                active_specialists.append((specialist_id, score))

        # Ensure at least one specialist is active
        if not active_specialists and sorted_specialists:
            active_specialists = [sorted_specialists[0]]

        return active_specialists
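    # Usage sketch (hypothetical output; the ids depend on domain_configs):
    #   router.route_text("The court ruled the contract void.", max_specialists=3)
    #   -> e.g. [(1, 0.7)] when only one domain's keywords match, since without
    #      embeddings the keyword score normalizes to 1.0 and is scaled by 0.7.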
    def chunk_and_route(self, text: str, chunk_size: int = 512) -> List[Dict]:
        """
        Split text into chunks and route each chunk.

        Returns:
            List of dicts with 'text', 'specialists', 'chunk_id'
        """
        # Simple sentence-based chunking
        sentences = re.split(r'[.!?]+', text)
        chunks = []
        current_chunk = ""
        chunk_id = 0

        for sentence in sentences:
            sentence = sentence.strip()
            if not sentence:
                # re.split leaves empty strings around trailing punctuation
                continue
            if len(current_chunk) + len(sentence) > chunk_size and current_chunk:
                # Route the current chunk and start a new one
                specialists = self.route_text(current_chunk)
                chunks.append({
                    'text': current_chunk.strip(),
                    'specialists': specialists,
                    'chunk_id': chunk_id
                })
                current_chunk = sentence + ". "
                chunk_id += 1
            else:
                current_chunk += sentence + ". "

        # Handle the last chunk
        if current_chunk.strip():
            specialists = self.route_text(current_chunk)
            chunks.append({
                'text': current_chunk.strip(),
                'specialists': specialists,
                'chunk_id': chunk_id
            })

        return chunks
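

# -----------------------------------------------------------------------------
# Minimal smoke test: a sketch, not part of the original module. It assumes
# only that `config` exposes `d_model`; the two toy domains are hypothetical.
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    config = SimpleNamespace(d_model=64)
    domains = [
        {"id": 0, "keywords": ["neural network", "gradient", "training"]},
        {"id": 1, "keywords": ["contract", "liability", "statute"]},
    ]
    router = TopicRouter(config, domains)

    text = ("Gradient descent trains the neural network. "
            "The statute limits liability under the contract.")
    # No embeddings passed, so routing falls back to keyword scores alone
    for chunk in router.chunk_and_route(text, chunk_size=50):
        print(chunk["chunk_id"], chunk["specialists"], repr(chunk["text"]))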