# =============================================================================
# routing/router.py
# =============================================================================
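"""Topic router for a multi-specialist model.

Combines keyword matching with a small neural scorer to decide which
topic specialists should handle a piece of text. Each entry of
``domain_configs`` is expected to provide at least an ``"id"`` and a
``"keywords"`` list, e.g. (illustrative values, inferred from how
``__init__`` reads the dicts):

    {"id": 0, "keywords": ["gradient", "backprop", "loss function"]}
"""
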
import torch
import torch.nn as nn
import numpy as np
from typing import List, Dict, Tuple, Optional
from collections import defaultdict
import re

from utils.domain_configs import DomainConfigs


class TopicRouter(nn.Module):
    def __init__(self, config, domain_configs: List[Dict]):
        super().__init__()
        self.config = config
        self.domain_configs = domain_configs
        self.num_specialists = len(domain_configs)

        # Build keyword mappings
        self.keyword_to_domains = defaultdict(list)
        self.domain_keywords = {}
        for domain in domain_configs:
            domain_id = domain["id"]
            keywords = domain["keywords"]
            self.domain_keywords[domain_id] = keywords
            for keyword in keywords:
                self.keyword_to_domains[keyword.lower()].append(domain_id)

        # Neural router for complex routing decisions
        self.neural_router = nn.Sequential(
            nn.Linear(config.d_model, 512),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, self.num_specialists)
        )

        # Text similarity threshold
        self.similarity_threshold = 0.1

    def keyword_based_routing(self, text: str) -> Dict[int, float]:
        """Route based on keyword matching"""
        text_lower = text.lower()
        domain_scores = defaultdict(float)

        # Count keyword matches for each domain (compare lowercased on both
        # sides, since configured keywords may carry arbitrary casing)
        for domain_id, keywords in self.domain_keywords.items():
            for keyword in keywords:
                keyword_lower = keyword.lower()
                if keyword_lower in text_lower:
                    # Weight by keyword frequency and length
                    count = text_lower.count(keyword_lower)
                    weight = len(keyword) / 10.0  # Longer keywords get higher weight
                    domain_scores[domain_id] += count * weight

        # Normalize scores
        total_score = sum(domain_scores.values())
        if total_score > 0:
            domain_scores = {k: v / total_score for k, v in domain_scores.items()}

        return dict(domain_scores)
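
    # Worked example (illustrative): for the text "gradient descent reduces
    # the loss", a matched keyword "gradient" (8 characters, one occurrence)
    # contributes 1 * 8/10 = 0.8 to its domain's score before normalization.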

    def neural_routing(self, embeddings: torch.Tensor) -> torch.Tensor:
        """Neural network based routing"""
        # Use mean pooling of embeddings
        pooled = embeddings.mean(dim=1)  # [batch, d_model]
        scores = self.neural_router(pooled)  # [batch, num_specialists]
        return torch.softmax(scores, dim=-1)

    def route_text(self, text: str, embeddings: Optional[torch.Tensor] = None,
                   max_specialists: int = 10) -> List[Tuple[int, float]]:
        """
        Route text to appropriate specialists

        Args:
            text: Input text to route
            embeddings: Text embeddings [1, seq_len, d_model]
            max_specialists: Maximum number of specialists to activate

        Returns:
            List of (specialist_id, confidence) tuples
        """
        # Keyword-based routing
        keyword_scores = self.keyword_based_routing(text)

        # Neural routing (if embeddings provided)
        neural_scores = {}
        if embeddings is not None:
            neural_weights = self.neural_routing(embeddings)
            neural_scores = {i: float(neural_weights[0, i])
                             for i in range(self.num_specialists)}

        # Combine scores (keyword matching dominates; neural scores refine)
        final_scores = {}
        for i in range(self.num_specialists):
            keyword_score = keyword_scores.get(i, 0.0)
            neural_score = neural_scores.get(i, 0.0)
            # Weighted combination
            final_scores[i] = 0.7 * keyword_score + 0.3 * neural_score

        # Sort by score and take top specialists
        sorted_specialists = sorted(final_scores.items(),
                                    key=lambda x: x[1],
                                    reverse=True)

        # Filter by threshold and limit
        active_specialists = []
        for specialist_id, score in sorted_specialists:
            if score > self.similarity_threshold and len(active_specialists) < max_specialists:
                active_specialists.append((specialist_id, score))

        # Ensure at least one specialist is active
        if not active_specialists and sorted_specialists:
            active_specialists = [sorted_specialists[0]]

        return active_specialists

    def chunk_and_route(self, text: str, chunk_size: int = 512) -> List[Dict]:
        """
        Split text into chunks and route each chunk

        Returns:
            List of dicts with 'text', 'specialists', 'chunk_id'
        """
        # Simple sentence-based chunking
        sentences = re.split(r'[.!?]+', text)
        chunks = []
        current_chunk = ""
        chunk_id = 0

        for sentence in sentences:
            sentence = sentence.strip()
            if not sentence:
                continue  # Skip empty fragments produced by the split
            if len(current_chunk) + len(sentence) > chunk_size and current_chunk:
                # Route current chunk, then start a new one with this sentence
                specialists = self.route_text(current_chunk)
                chunks.append({
                    'text': current_chunk.strip(),
                    'specialists': specialists,
                    'chunk_id': chunk_id
                })
                current_chunk = sentence + ". "
                chunk_id += 1
            else:
                current_chunk += sentence + ". "

        # Handle last chunk
        if current_chunk.strip():
            specialists = self.route_text(current_chunk)
            chunks.append({
                'text': current_chunk.strip(),
                'specialists': specialists,
                'chunk_id': chunk_id
            })

        return chunks
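

# Minimal usage sketch (illustrative, not part of the module). It assumes a
# config object exposing ``d_model`` and hand-written domain configs; real
# ones would come from utils.domain_configs.
if __name__ == "__main__":
    from types import SimpleNamespace

    config = SimpleNamespace(d_model=768)  # hypothetical config
    domain_configs = [
        {"id": 0, "keywords": ["gradient", "loss function", "backprop"]},
        {"id": 1, "keywords": ["tokenizer", "embedding", "attention"]},
    ]
    router = TopicRouter(config, domain_configs)

    # Keyword-only routing (no embeddings supplied)
    print(router.route_text("How does backprop update the loss function?"))

    # Combined routing with dummy embeddings of shape [1, seq_len, d_model]
    embeddings = torch.randn(1, 16, config.d_model)
    print(router.route_text("Attention layers and tokenizer tricks", embeddings))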