Spaces:
Sleeping
Sleeping
Upload app.py
Browse files
app.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
#!/usr/bin/env python3
|
2 |
"""
|
3 |
-
|
4 |
-
|
5 |
"""
|
6 |
|
7 |
import gradio as gr
|
@@ -14,7 +14,8 @@ import os
|
|
14 |
import psutil
|
15 |
from typing import Optional, Dict, Any, Tuple
|
16 |
from datetime import datetime
|
17 |
-
from transformers import AutoTokenizer, AutoConfig
|
|
|
18 |
|
19 |
# Setup comprehensive logging
|
20 |
logging.basicConfig(
|
@@ -27,8 +28,106 @@ logging.basicConfig(
|
|
27 |
)
|
28 |
logger = logging.getLogger(__name__)
|
29 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
class MambaSwarmDemo:
|
31 |
-
"""Production-ready Mamba Swarm Demo with
|
32 |
|
33 |
def __init__(self, model_path: str = "./", fallback_mode: bool = False):
|
34 |
self.model = None
|
@@ -38,6 +137,8 @@ class MambaSwarmDemo:
|
|
38 |
self.model_path = model_path
|
39 |
self.fallback_mode = fallback_mode
|
40 |
self.model_loaded = False
|
|
|
|
|
41 |
|
42 |
# Performance tracking
|
43 |
self.stats = {
|
@@ -60,24 +161,23 @@ class MambaSwarmDemo:
|
|
60 |
}
|
61 |
|
62 |
self._initialize_model()
|
63 |
-
logger.info(f"Demo initialized - Model loaded: {self.model_loaded}, Fallback mode: {self.fallback_mode}")
|
64 |
|
65 |
def _initialize_model(self):
|
66 |
-
"""Initialize model with
|
67 |
try:
|
68 |
-
logger.info("Attempting to load
|
69 |
|
70 |
-
#
|
71 |
-
|
72 |
-
if not os.path.exists(config_path) and not self.fallback_mode:
|
73 |
-
logger.warning(f"Config file not found at {config_path}, enabling fallback mode")
|
74 |
-
self.fallback_mode = True
|
75 |
|
76 |
-
if not
|
77 |
-
|
78 |
-
self.
|
79 |
-
|
80 |
-
|
|
|
|
|
81 |
self._initialize_fallback_mode()
|
82 |
|
83 |
except Exception as e:
|
@@ -86,132 +186,136 @@ class MambaSwarmDemo:
|
|
86 |
self.fallback_mode = True
|
87 |
self._initialize_fallback_mode()
|
88 |
|
89 |
-
def
|
90 |
-
"""Load
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
91 |
try:
|
92 |
-
|
|
|
|
|
93 |
model_class = None
|
94 |
|
95 |
-
# Try importing from different locations
|
96 |
try:
|
97 |
from modeling_mamba_swarm import MambaSwarmForCausalLM
|
98 |
model_class = MambaSwarmForCausalLM
|
99 |
-
logger.info("
|
100 |
except ImportError:
|
101 |
try:
|
102 |
-
from
|
103 |
-
model_class =
|
104 |
-
logger.info("
|
105 |
except ImportError:
|
106 |
try:
|
107 |
-
from
|
108 |
-
|
109 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
110 |
except ImportError:
|
111 |
-
|
112 |
-
|
113 |
-
# Use the unified swarm in native mode
|
114 |
-
swarm = UnifiedMambaSwarm(use_pretrained=False)
|
115 |
-
if hasattr(swarm, 'native_swarm_model') and swarm.native_swarm_model:
|
116 |
-
self.model = swarm.native_swarm_model
|
117 |
-
self.model_loaded = True
|
118 |
-
logger.info("Loaded native swarm model from UnifiedMambaSwarm")
|
119 |
-
return
|
120 |
-
else:
|
121 |
-
raise ImportError("No native swarm model available")
|
122 |
-
except ImportError as e:
|
123 |
-
logger.error(f"All model imports failed: {e}")
|
124 |
-
raise ImportError("No compatible Mamba Swarm model found")
|
125 |
|
126 |
if model_class is None:
|
127 |
-
|
128 |
|
129 |
-
#
|
130 |
try:
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
|
|
|
|
|
|
|
|
|
|
136 |
try:
|
137 |
-
from modeling_mamba_swarm import MambaSwarmConfig
|
138 |
-
self.config = MambaSwarmConfig(
|
139 |
-
num_encoders=8,
|
140 |
-
max_mamba_encoders=100,
|
141 |
-
d_model=768,
|
142 |
-
vocab_size=50257,
|
143 |
-
max_sequence_length=2048
|
144 |
-
)
|
145 |
-
logger.info("Using default MambaSwarmConfig")
|
146 |
-
except ImportError:
|
147 |
-
# Final fallback to basic config
|
148 |
from core.config import MambaConfig
|
149 |
self.config = MambaConfig()
|
150 |
-
# Add swarm-specific attributes
|
151 |
self.config.num_encoders = 8
|
152 |
self.config.max_mamba_encoders = 100
|
153 |
-
|
154 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
155 |
|
156 |
-
#
|
157 |
-
|
158 |
-
self.
|
159 |
-
if self.tokenizer.pad_token is None:
|
160 |
-
self.tokenizer.pad_token = self.tokenizer.eos_token
|
161 |
-
logger.info("Tokenizer loaded successfully")
|
162 |
-
except Exception as e:
|
163 |
-
logger.warning(f"Could not load tokenizer: {e}")
|
164 |
-
# Use a simple fallback tokenizer
|
165 |
-
from transformers import GPT2Tokenizer
|
166 |
-
self.tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
|
167 |
-
if self.tokenizer.pad_token is None:
|
168 |
-
self.tokenizer.pad_token = self.tokenizer.eos_token
|
169 |
-
logger.info("Using fallback GPT2 tokenizer")
|
170 |
-
|
171 |
-
# Load model with memory optimization
|
172 |
-
dtype = torch.float16 if self.device.type == "cuda" else torch.float32
|
173 |
-
|
174 |
-
if model_class == MambaEncoderSwarmModel:
|
175 |
-
# Native integration model - create with MambaConfig
|
176 |
-
from core.config import MambaConfig
|
177 |
-
if not hasattr(self, 'config') or not isinstance(self.config, MambaConfig):
|
178 |
-
mamba_config = MambaConfig(
|
179 |
-
d_model=getattr(self.config, 'd_model', 768),
|
180 |
-
vocab_size=getattr(self.config, 'vocab_size', 50257),
|
181 |
-
n_layers=8,
|
182 |
-
d_state=16,
|
183 |
-
d_conv=4,
|
184 |
-
bias=False
|
185 |
-
)
|
186 |
-
self.model = model_class(mamba_config, num_encoders=getattr(self.config, 'num_encoders', 8))
|
187 |
-
else:
|
188 |
-
self.model = model_class(self.config, num_encoders=getattr(self.config, 'num_encoders', 8))
|
189 |
else:
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
low_cpu_mem_usage=True
|
198 |
-
)
|
199 |
-
else:
|
200 |
-
# Create with config only
|
201 |
-
self.model = model_class(self.config)
|
202 |
|
203 |
self.model.to(self.device)
|
204 |
self.model.eval()
|
205 |
self.model_loaded = True
|
206 |
|
207 |
-
|
208 |
-
|
209 |
-
logger.info(f"Model loaded successfully on {self.device}")
|
210 |
-
logger.info(f"Model parameters: {num_params:,} ({num_params/1e6:.1f}M)")
|
211 |
|
212 |
except Exception as e:
|
213 |
-
logger.error(f"
|
214 |
-
|
215 |
|
216 |
def _initialize_fallback_mode(self):
|
217 |
"""Initialize fallback/simulation mode"""
|
@@ -246,7 +350,6 @@ class MambaSwarmDemo:
|
|
246 |
self.eos_token = "[EOS]"
|
247 |
|
248 |
def encode(self, text, return_tensors=None):
|
249 |
-
# Simple word-based tokenization for simulation
|
250 |
tokens = text.split()
|
251 |
token_ids = [hash(token) % 1000 for token in tokens]
|
252 |
if return_tensors == "pt":
|
@@ -254,7 +357,6 @@ class MambaSwarmDemo:
|
|
254 |
return token_ids
|
255 |
|
256 |
def decode(self, token_ids, skip_special_tokens=True):
|
257 |
-
# Mock decoding
|
258 |
return f"Generated response for {len(token_ids)} tokens"
|
259 |
|
260 |
self.tokenizer = MockTokenizer()
|
@@ -310,7 +412,7 @@ class MambaSwarmDemo:
|
|
310 |
available_encoders = list(range(start, min(end + 1, 101)))
|
311 |
|
312 |
# Select encoders based on prompt complexity and domain
|
313 |
-
prompt_complexity = min(len(prompt.split()) / 10, 3.0)
|
314 |
optimal_count = min(max(int(num_encoders * (1 + prompt_complexity)), 3), 25)
|
315 |
|
316 |
if len(available_encoders) >= optimal_count:
|
@@ -333,171 +435,9 @@ class MambaSwarmDemo:
|
|
333 |
'total_active': len(selected_encoders)
|
334 |
}
|
335 |
|
336 |
-
def _simulate_generation(self, prompt: str, routing_info: Dict, max_length: int) -> str:
|
337 |
-
"""Generate sophisticated simulated responses based on domain"""
|
338 |
-
domain = routing_info['detected_domain']
|
339 |
-
|
340 |
-
domain_responses = {
|
341 |
-
'medical': f"""Based on medical literature and current research, regarding "{prompt[:50]}...":
|
342 |
-
|
343 |
-
This condition/topic involves multiple factors including genetic predisposition, environmental influences, and lifestyle factors. Key considerations include:
|
344 |
-
|
345 |
-
β’ Proper medical evaluation is essential
|
346 |
-
β’ Individual symptoms may vary significantly
|
347 |
-
β’ Treatment approaches should be personalized
|
348 |
-
β’ Regular monitoring is typically recommended
|
349 |
-
|
350 |
-
**Important**: This information is for educational purposes only. Please consult with qualified healthcare professionals for personalized medical advice and treatment recommendations.""",
|
351 |
-
|
352 |
-
'legal': f"""From a legal perspective on "{prompt[:50]}...":
|
353 |
-
|
354 |
-
The legal framework surrounding this matter involves several key considerations:
|
355 |
-
|
356 |
-
β’ Jurisdictional requirements and applicable statutes
|
357 |
-
β’ Precedent cases and regulatory guidelines
|
358 |
-
β’ Compliance obligations and reporting requirements
|
359 |
-
β’ Risk assessment and mitigation strategies
|
360 |
-
|
361 |
-
**Disclaimer**: This information is for general informational purposes only and does not constitute legal advice. Consult with qualified legal professionals for specific legal matters.""",
|
362 |
-
|
363 |
-
'code': f"""Here's a comprehensive solution for "{prompt[:50]}...":
|
364 |
-
|
365 |
-
```python
|
366 |
-
def optimized_solution(input_data):
|
367 |
-
\"\"\"
|
368 |
-
Efficient implementation with error handling
|
369 |
-
Time complexity: O(n log n)
|
370 |
-
Space complexity: O(n)
|
371 |
-
\"\"\"
|
372 |
-
try:
|
373 |
-
# Input validation
|
374 |
-
if not input_data:
|
375 |
-
raise ValueError("Input data cannot be empty")
|
376 |
-
|
377 |
-
# Core algorithm implementation
|
378 |
-
result = process_data(input_data)
|
379 |
-
|
380 |
-
# Additional optimizations
|
381 |
-
result = optimize_output(result)
|
382 |
-
|
383 |
-
return result
|
384 |
-
|
385 |
-
except Exception as e:
|
386 |
-
logger.error(f"Processing error: {{e}}")
|
387 |
-
return None
|
388 |
-
|
389 |
-
def process_data(data):
|
390 |
-
# Implementation details here
|
391 |
-
return processed_data
|
392 |
-
|
393 |
-
def optimize_output(data):
|
394 |
-
# Performance optimizations
|
395 |
-
return optimized_data
|
396 |
-
```
|
397 |
-
|
398 |
-
**Key Features:**
|
399 |
-
β’ Error handling and input validation
|
400 |
-
β’ Optimized performance characteristics
|
401 |
-
β’ Comprehensive documentation
|
402 |
-
β’ Production-ready implementation""",
|
403 |
-
|
404 |
-
'science': f"""Scientific Analysis of "{prompt[:50]}...":
|
405 |
-
|
406 |
-
Based on current scientific understanding and peer-reviewed research:
|
407 |
-
|
408 |
-
**Theoretical Framework:**
|
409 |
-
The underlying principles involve complex interactions between multiple variables, governed by established scientific laws and emerging theories.
|
410 |
-
|
411 |
-
**Methodology:**
|
412 |
-
β’ Systematic observation and data collection
|
413 |
-
β’ Controlled experimental design
|
414 |
-
β’ Statistical analysis and validation
|
415 |
-
β’ Peer review and reproducibility testing
|
416 |
-
|
417 |
-
**Current Research:**
|
418 |
-
Recent studies indicate significant progress in understanding the mechanisms involved, with several promising avenues for future investigation.
|
419 |
-
|
420 |
-
**Implications:**
|
421 |
-
These findings have broad applications across multiple disciplines and may lead to significant advances in the field.""",
|
422 |
-
|
423 |
-
'creative': f"""**{prompt[:30]}...**
|
424 |
-
|
425 |
-
The story unfolds in a world where imagination meets reality, where every character carries the weight of their dreams and the burden of their choices.
|
426 |
-
|
427 |
-
*Chapter 1: The Beginning*
|
428 |
-
|
429 |
-
In the quiet moments before dawn, when the world holds its breath between night and day, our tale begins. The protagonist stands at the threshold of an adventure that will challenge everything they thought they knew about themselves and the world around them.
|
430 |
-
|
431 |
-
The narrative weaves through layers of meaning, exploring themes of identity, purpose, and the delicate balance between hope and reality. Each scene is crafted with careful attention to emotional resonance and character development.
|
432 |
-
|
433 |
-
*As the story progresses, we discover that the true journey is not external, but internalβa transformation of the soul that mirrors the changing landscape of the world itself.*
|
434 |
-
|
435 |
-
**Themes Explored:**
|
436 |
-
β’ Personal growth and self-discovery
|
437 |
-
β’ The power of resilience and determination
|
438 |
-
β’ The complexity of human relationships
|
439 |
-
β’ The intersection of dreams and reality""",
|
440 |
-
|
441 |
-
'business': f"""**Strategic Analysis: {prompt[:50]}...**
|
442 |
-
|
443 |
-
**Executive Summary:**
|
444 |
-
This comprehensive analysis examines the strategic implications and market opportunities related to the identified business challenge.
|
445 |
-
|
446 |
-
**Market Assessment:**
|
447 |
-
β’ Current market size and growth projections
|
448 |
-
β’ Competitive landscape analysis
|
449 |
-
β’ Key trends and disruption factors
|
450 |
-
β’ Customer segment identification
|
451 |
-
|
452 |
-
**Strategic Recommendations:**
|
453 |
-
1. **Short-term actions** (0-6 months)
|
454 |
-
- Immediate market positioning
|
455 |
-
- Resource allocation optimization
|
456 |
-
- Risk mitigation strategies
|
457 |
-
|
458 |
-
2. **Medium-term initiatives** (6-18 months)
|
459 |
-
- Strategic partnerships and alliances
|
460 |
-
- Product/service development
|
461 |
-
- Market expansion opportunities
|
462 |
-
|
463 |
-
3. **Long-term vision** (18+ months)
|
464 |
-
- Innovation and R&D investment
|
465 |
-
- Scalability and sustainability
|
466 |
-
- Market leadership positioning
|
467 |
-
|
468 |
-
**Financial Projections:**
|
469 |
-
Based on conservative estimates, implementation of these strategies could result in significant ROI and market share growth.""",
|
470 |
-
|
471 |
-
'general': f"""**Comprehensive Response to: "{prompt[:50]}..."**
|
472 |
-
|
473 |
-
Thank you for your inquiry. Based on available knowledge and expertise from {routing_info['total_active']} specialized domains, here's a comprehensive analysis:
|
474 |
-
|
475 |
-
**Key Points:**
|
476 |
-
β’ The topic involves multiple interconnected factors that require careful consideration
|
477 |
-
β’ Current understanding is based on established principles and ongoing research
|
478 |
-
β’ Practical applications vary depending on specific context and requirements
|
479 |
-
β’ Best practices emphasize a balanced, evidence-based approach
|
480 |
-
|
481 |
-
**Detailed Analysis:**
|
482 |
-
The subject matter encompasses several important dimensions that merit thorough examination. Each aspect contributes to a deeper understanding of the overall concept and its implications.
|
483 |
-
|
484 |
-
**Practical Considerations:**
|
485 |
-
Implementation requires careful planning, adequate resources, and ongoing monitoring to ensure optimal outcomes. Success factors include stakeholder engagement, clear communication, and adaptive management strategies.
|
486 |
-
|
487 |
-
**Conclusion:**
|
488 |
-
This analysis provides a foundation for informed decision-making while acknowledging the complexity and nuanced nature of the topic."""
|
489 |
-
}
|
490 |
-
|
491 |
-
return domain_responses.get(domain, domain_responses['general'])
|
492 |
-
|
493 |
def generate_text(self, prompt: str, max_length: int = 100, temperature: float = 0.7,
|
494 |
top_p: float = 0.9, num_encoders: int = 5, show_routing: bool = True) -> Tuple[str, str]:
|
495 |
-
"""
|
496 |
-
Generate text with comprehensive error handling and routing information
|
497 |
-
|
498 |
-
Returns:
|
499 |
-
Tuple of (generated_text, routing_info_display)
|
500 |
-
"""
|
501 |
start_time = time.time()
|
502 |
|
503 |
# Update statistics
|
@@ -514,7 +454,7 @@ This analysis provides a foundation for informed decision-making while acknowled
|
|
514 |
# Real model generation
|
515 |
response = self._generate_real(prompt, max_length, temperature, top_p, num_encoders)
|
516 |
else:
|
517 |
-
# Simulated generation
|
518 |
response = self._simulate_generation(prompt, routing_info, max_length)
|
519 |
|
520 |
# Calculate performance metrics
|
@@ -546,46 +486,127 @@ This analysis provides a foundation for informed decision-making while acknowled
|
|
546 |
|
547 |
def _generate_real(self, prompt: str, max_length: int, temperature: float,
|
548 |
top_p: float, num_encoders: int) -> str:
|
549 |
-
"""Generate using real model"""
|
550 |
try:
|
551 |
# Encode input
|
552 |
inputs = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
|
553 |
|
554 |
-
# Adjust number of active encoders
|
555 |
if hasattr(self.model, 'set_active_encoders'):
|
556 |
-
self.
|
|
|
557 |
|
558 |
# Generate with memory optimization
|
559 |
with torch.no_grad():
|
560 |
-
|
561 |
-
|
562 |
-
|
563 |
-
|
564 |
-
|
565 |
-
|
566 |
-
|
567 |
-
|
568 |
-
|
569 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
570 |
|
571 |
# Decode output
|
572 |
generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
|
573 |
|
574 |
# Remove input prompt from output
|
575 |
-
|
|
|
|
|
|
|
576 |
|
577 |
return response if response else "Generated response was empty."
|
578 |
|
579 |
except torch.cuda.OutOfMemoryError:
|
580 |
logger.error("CUDA out of memory during generation")
|
581 |
-
return "Error: GPU memory insufficient. Try reducing max_length or
|
582 |
except Exception as e:
|
583 |
logger.error(f"Real generation error: {e}")
|
584 |
-
return f"Generation error: {str(e)}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
585 |
|
586 |
def _create_routing_display(self, routing_info: Dict, generation_time: float,
|
587 |
estimated_tokens: int) -> str:
|
588 |
"""Create rich routing information display"""
|
|
|
|
|
|
|
589 |
return f"""
|
590 |
## π§ Intelligent Routing Analysis
|
591 |
|
@@ -594,10 +615,11 @@ This analysis provides a foundation for informed decision-making while acknowled
|
|
594 |
- **Confidence**: {routing_info['domain_confidence']:.1%}
|
595 |
- **Specialization Level**: {'High' if routing_info['domain_confidence'] > 0.7 else 'Medium' if routing_info['domain_confidence'] > 0.4 else 'General'}
|
596 |
|
597 |
-
**β‘
|
598 |
-
- **
|
599 |
-
- **
|
600 |
-
- **
|
|
|
601 |
|
602 |
**π’ Selected Encoder IDs:**
|
603 |
{', '.join(map(str, routing_info['selected_encoders'][:15]))}{'...' if len(routing_info['selected_encoders']) > 15 else ''}
|
@@ -606,15 +628,15 @@ This analysis provides a foundation for informed decision-making while acknowled
|
|
606 |
- **Generation Time**: {generation_time:.2f}s
|
607 |
- **Estimated Tokens**: {estimated_tokens}
|
608 |
- **Tokens/Second**: {estimated_tokens/generation_time:.1f}
|
609 |
-
- **
|
610 |
|
611 |
**ποΈ Confidence Scores (Top 5):**
|
612 |
{', '.join([f'{score:.3f}' for score in routing_info['confidence_scores'][:5]])}{'...' if len(routing_info['confidence_scores']) > 5 else ''}
|
613 |
|
614 |
**π‘ Optimization Notes:**
|
615 |
- Encoder selection optimized for domain: {routing_info['detected_domain']}
|
|
|
616 |
- Dynamic load balancing across {routing_info['total_active']} active encoders
|
617 |
-
- Confidence-weighted aggregation applied
|
618 |
"""
|
619 |
|
620 |
def get_model_info(self) -> str:
|
@@ -628,21 +650,39 @@ This analysis provides a foundation for informed decision-making while acknowled
|
|
628 |
if torch.cuda.is_available():
|
629 |
gpu_info = f"{torch.cuda.get_device_name(0)} ({torch.cuda.get_device_properties(0).total_memory // 1024**3}GB)"
|
630 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
631 |
return f"""
|
632 |
**π€ Mamba Encoder Swarm Model Information**
|
633 |
|
634 |
**Model Configuration:**
|
635 |
-
- **Status**: {
|
636 |
- **Active Encoders**: {getattr(self.model, 'num_active_encoders', 'N/A')}
|
637 |
-
- **Max Encoders**: {self.config
|
638 |
-
- **Model Dimension**: {self.config.
|
639 |
-
- **Vocabulary Size**: {self.config
|
640 |
- **Max Sequence Length**: {getattr(self.config, 'max_sequence_length', 'N/A')}
|
641 |
-
|
642 |
**System Information:**
|
643 |
- **Device**: {self.device} {f'({gpu_info})' if gpu_info != 'N/A' else ''}
|
644 |
- **RAM Usage**: {memory_info.percent:.1f}% ({memory_info.used // 1024**3}GB / {memory_info.total // 1024**3}GB)
|
645 |
-
- **
|
646 |
|
647 |
**Performance Statistics:**
|
648 |
- **Total Requests**: {self.stats['total_requests']}
|
@@ -652,23 +692,83 @@ This analysis provides a foundation for informed decision-making while acknowled
|
|
652 |
- **Avg Generation Time**: {self.stats['avg_generation_time']:.2f}s
|
653 |
- **Total Tokens Generated**: {self.stats['total_tokens_generated']:,}
|
654 |
|
655 |
-
**
|
656 |
"""
|
657 |
|
658 |
def get_system_status(self) -> Dict[str, Any]:
|
659 |
"""Get system status for monitoring"""
|
660 |
return {
|
661 |
'model_loaded': self.model_loaded,
|
|
|
662 |
'fallback_mode': self.fallback_mode,
|
663 |
'device': str(self.device),
|
664 |
'stats': self.stats.copy(),
|
665 |
'timestamp': datetime.now().isoformat()
|
666 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
667 |
|
668 |
def create_production_demo() -> gr.Blocks:
|
669 |
-
"""Create production-ready Gradio interface"""
|
670 |
|
671 |
-
# Initialize demo with
|
672 |
try:
|
673 |
demo_instance = MambaSwarmDemo(model_path="./", fallback_mode=False)
|
674 |
except Exception as e:
|
@@ -684,9 +784,13 @@ def create_production_demo() -> gr.Blocks:
|
|
684 |
def refresh_model_info():
|
685 |
return demo_instance.get_model_info()
|
686 |
|
|
|
|
|
|
|
|
|
687 |
# Create interface
|
688 |
with gr.Blocks(
|
689 |
-
title="Mamba Encoder Swarm - Production Demo",
|
690 |
theme=gr.themes.Soft(),
|
691 |
css="""
|
692 |
.gradio-container {
|
@@ -705,6 +809,13 @@ def create_production_demo() -> gr.Blocks:
|
|
705 |
padding: 15px;
|
706 |
margin: 10px 0;
|
707 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
708 |
"""
|
709 |
) as demo:
|
710 |
|
@@ -712,18 +823,29 @@ def create_production_demo() -> gr.Blocks:
|
|
712 |
gr.Markdown("""
|
713 |
# π Mamba Encoder Swarm - Production Demo
|
714 |
|
715 |
-
**Advanced Language Model with
|
716 |
|
717 |
-
|
718 |
-
|
719 |
""")
|
720 |
|
721 |
# Status indicator
|
722 |
with gr.Row():
|
723 |
-
with gr.Column(scale=
|
|
|
724 |
status_indicator = gr.Markdown(
|
725 |
-
f"**Status**: {
|
|
|
726 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
727 |
|
728 |
with gr.Row():
|
729 |
# Left column - Input and controls
|
@@ -803,7 +925,14 @@ def create_production_demo() -> gr.Blocks:
|
|
803 |
value=show_model_info(),
|
804 |
elem_classes=["model-info"]
|
805 |
)
|
806 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
807 |
|
808 |
# Examples section
|
809 |
with gr.Accordion("π‘ Example Prompts", open=True):
|
@@ -816,7 +945,8 @@ def create_production_demo() -> gr.Blocks:
|
|
816 |
["Analyze the legal implications of AI-generated content", 350, 0.7, 0.9, 15, True],
|
817 |
["Write a creative short story about a time-traveling scientist", 400, 0.9, 0.95, 12, True],
|
818 |
["Develop a marketing strategy for a sustainable fashion startup", 300, 0.8, 0.9, 10, True],
|
819 |
-
["How does quantum entanglement work and what are its applications?", 350, 0.6, 0.9, 15, True]
|
|
|
820 |
]
|
821 |
|
822 |
gr.Examples(
|
@@ -828,6 +958,28 @@ def create_production_demo() -> gr.Blocks:
|
|
828 |
label="Click any example to load it"
|
829 |
)
|
830 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
831 |
# Event handlers
|
832 |
generate_btn.click(
|
833 |
fn=generate_response,
|
@@ -841,15 +993,36 @@ def create_production_demo() -> gr.Blocks:
|
|
841 |
outputs=model_info_display
|
842 |
)
|
843 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
844 |
# Footer
|
845 |
gr.Markdown("""
|
846 |
---
|
847 |
-
### ποΈ Architecture Overview
|
|
|
|
|
|
|
|
|
|
|
|
|
848 |
|
849 |
**π§ Intelligent Routing System**
|
850 |
- Domain detection based on prompt analysis
|
851 |
- Dynamic encoder selection optimized for content type
|
852 |
- Load balancing across specialized encoder pools
|
|
|
853 |
|
854 |
**π§ Production Features**
|
855 |
- Comprehensive error handling and fallback modes
|
@@ -861,7 +1034,7 @@ def create_production_demo() -> gr.Blocks:
|
|
861 |
- **Medical & Healthcare** β’ **Legal & Regulatory** β’ **Code & Technical**
|
862 |
- **Science & Research** β’ **Creative Writing** β’ **Business & Finance**
|
863 |
|
864 |
-
Built with β€οΈ using Gradio, PyTorch, and the Mamba architecture
|
865 |
""")
|
866 |
|
867 |
return demo
|
@@ -918,4 +1091,4 @@ if __name__ == "__main__":
|
|
918 |
demo.launch(share=False, debug=False)
|
919 |
except Exception as e2:
|
920 |
logger.error(f"Minimal launch also failed: {e2}")
|
921 |
-
print(f"β All launch attempts failed. Error: {e2}")
|
|
|
1 |
#!/usr/bin/env python3
|
2 |
"""
|
3 |
+
Enhanced Production-Ready Mamba Encoder Swarm Demo
|
4 |
+
Integrates pretrained Mamba weights from HuggingFace with swarm architecture
|
5 |
"""
|
6 |
|
7 |
import gradio as gr
|
|
|
14 |
import psutil
|
15 |
from typing import Optional, Dict, Any, Tuple
|
16 |
from datetime import datetime
|
17 |
+
from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM
|
18 |
+
from huggingface_hub import snapshot_download, hf_hub_download
|
19 |
|
20 |
# Setup comprehensive logging
|
21 |
logging.basicConfig(
|
|
|
28 |
)
|
29 |
logger = logging.getLogger(__name__)
|
30 |
|
31 |
+
class MambaWeightLoader:
|
32 |
+
"""Dynamic loader for pretrained Mamba weights"""
|
33 |
+
|
34 |
+
def __init__(self, model_name="state-spaces/mamba-130m"):
|
35 |
+
self.model_name = model_name
|
36 |
+
self.cache_dir = "/tmp/mamba_cache" if os.path.exists("/tmp") else "./mamba_cache"
|
37 |
+
self.model = None
|
38 |
+
self.tokenizer = None
|
39 |
+
self.config = None
|
40 |
+
|
41 |
+
def download_and_load(self):
|
42 |
+
"""Download and load Mamba weights in HuggingFace Spaces"""
|
43 |
+
try:
|
44 |
+
logger.info(f"π Loading pretrained model: {self.model_name}")
|
45 |
+
|
46 |
+
# Create cache directory
|
47 |
+
os.makedirs(self.cache_dir, exist_ok=True)
|
48 |
+
|
49 |
+
# Load tokenizer (lightweight)
|
50 |
+
logger.info("π Loading tokenizer...")
|
51 |
+
self.tokenizer = AutoTokenizer.from_pretrained(
|
52 |
+
self.model_name,
|
53 |
+
cache_dir=self.cache_dir,
|
54 |
+
trust_remote_code=True
|
55 |
+
)
|
56 |
+
|
57 |
+
# Handle tokenizer padding
|
58 |
+
if self.tokenizer.pad_token is None:
|
59 |
+
if self.tokenizer.eos_token is not None:
|
60 |
+
self.tokenizer.pad_token = self.tokenizer.eos_token
|
61 |
+
else:
|
62 |
+
self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
|
63 |
+
|
64 |
+
# Load configuration
|
65 |
+
logger.info("βοΈ Loading model configuration...")
|
66 |
+
self.config = AutoConfig.from_pretrained(
|
67 |
+
self.model_name,
|
68 |
+
cache_dir=self.cache_dir,
|
69 |
+
trust_remote_code=True
|
70 |
+
)
|
71 |
+
|
72 |
+
# Load model with optimizations for Spaces
|
73 |
+
logger.info("π§ Loading model weights...")
|
74 |
+
|
75 |
+
# Determine optimal dtype and device settings
|
76 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
77 |
+
dtype = torch.float16 if device.type == "cuda" else torch.float32
|
78 |
+
|
79 |
+
self.model = AutoModelForCausalLM.from_pretrained(
|
80 |
+
self.model_name,
|
81 |
+
config=self.config,
|
82 |
+
cache_dir=self.cache_dir,
|
83 |
+
trust_remote_code=True,
|
84 |
+
torch_dtype=dtype,
|
85 |
+
device_map="auto" if torch.cuda.is_available() else None,
|
86 |
+
low_cpu_mem_usage=True
|
87 |
+
)
|
88 |
+
|
89 |
+
# Move to device if not using device_map
|
90 |
+
if not torch.cuda.is_available():
|
91 |
+
self.model.to(device)
|
92 |
+
|
93 |
+
self.model.eval()
|
94 |
+
|
95 |
+
# Log model info
|
96 |
+
num_params = sum(p.numel() for p in self.model.parameters())
|
97 |
+
logger.info(f"β
Model loaded successfully!")
|
98 |
+
logger.info(f"π Parameters: {num_params:,} ({num_params/1e6:.1f}M)")
|
99 |
+
logger.info(f"π§ Device: {device}, dtype: {dtype}")
|
100 |
+
|
101 |
+
return True
|
102 |
+
|
103 |
+
except Exception as e:
|
104 |
+
logger.error(f"β Error loading pretrained model: {e}")
|
105 |
+
return False
|
106 |
+
|
107 |
+
def get_model_info(self):
|
108 |
+
"""Get model information"""
|
109 |
+
if self.model:
|
110 |
+
try:
|
111 |
+
num_params = sum(p.numel() for p in self.model.parameters())
|
112 |
+
device = next(self.model.parameters()).device
|
113 |
+
dtype = next(self.model.parameters()).dtype
|
114 |
+
|
115 |
+
return {
|
116 |
+
"name": self.model_name,
|
117 |
+
"parameters": f"{num_params:,}",
|
118 |
+
"parameters_millions": f"{num_params/1e6:.1f}M",
|
119 |
+
"device": str(device),
|
120 |
+
"dtype": str(dtype),
|
121 |
+
"vocab_size": getattr(self.config, 'vocab_size', 'Unknown'),
|
122 |
+
"hidden_size": getattr(self.config, 'd_model', getattr(self.config, 'hidden_size', 'Unknown'))
|
123 |
+
}
|
124 |
+
except Exception as e:
|
125 |
+
logger.error(f"Error getting model info: {e}")
|
126 |
+
return {"error": str(e)}
|
127 |
+
return None
|
128 |
+
|
129 |
class MambaSwarmDemo:
|
130 |
+
"""Enhanced Production-ready Mamba Swarm Demo with dynamic pretrained weight loading"""
|
131 |
|
132 |
def __init__(self, model_path: str = "./", fallback_mode: bool = False):
|
133 |
self.model = None
|
|
|
137 |
self.model_path = model_path
|
138 |
self.fallback_mode = fallback_mode
|
139 |
self.model_loaded = False
|
140 |
+
self.pretrained_loader = None
|
141 |
+
self.using_pretrained = False
|
142 |
|
143 |
# Performance tracking
|
144 |
self.stats = {
|
|
|
161 |
}
|
162 |
|
163 |
self._initialize_model()
|
164 |
+
logger.info(f"Demo initialized - Model loaded: {self.model_loaded}, Using pretrained: {self.using_pretrained}, Fallback mode: {self.fallback_mode}")
|
165 |
|
166 |
def _initialize_model(self):
|
167 |
+
"""Initialize model with pretrained weights or fallback"""
|
168 |
try:
|
169 |
+
logger.info("π Attempting to load model with priority: Pretrained -> Custom -> Fallback")
|
170 |
|
171 |
+
# Try to load pretrained model first (highest priority)
|
172 |
+
success = self._load_pretrained_model()
|
|
|
|
|
|
|
173 |
|
174 |
+
if not success:
|
175 |
+
logger.info("Pretrained loading failed, trying custom swarm model...")
|
176 |
+
success = self._load_custom_swarm_model()
|
177 |
+
|
178 |
+
if not success:
|
179 |
+
logger.info("All model loading attempts failed, enabling fallback mode")
|
180 |
+
self.fallback_mode = True
|
181 |
self._initialize_fallback_mode()
|
182 |
|
183 |
except Exception as e:
|
|
|
186 |
self.fallback_mode = True
|
187 |
self._initialize_fallback_mode()
|
188 |
|
189 |
+
def _load_pretrained_model(self):
|
190 |
+
"""Load pretrained Mamba model from HuggingFace with automatic model selection"""
|
191 |
+
try:
|
192 |
+
# Choose model based on available resources
|
193 |
+
MODEL_OPTIONS = {
|
194 |
+
"small": "state-spaces/mamba-130m", # ~500MB
|
195 |
+
"medium": "state-spaces/mamba-790m", # ~3GB
|
196 |
+
"large": "state-spaces/mamba-1.4b", # ~5GB
|
197 |
+
"xl": "state-spaces/mamba-2.8b", # ~10GB
|
198 |
+
}
|
199 |
+
|
200 |
+
# Auto-select model based on available memory
|
201 |
+
memory_gb = psutil.virtual_memory().total / (1024**3)
|
202 |
+
if memory_gb >= 32 and torch.cuda.is_available():
|
203 |
+
selected_model = MODEL_OPTIONS["xl"]
|
204 |
+
elif memory_gb >= 16 and torch.cuda.is_available():
|
205 |
+
selected_model = MODEL_OPTIONS["large"]
|
206 |
+
elif memory_gb >= 8:
|
207 |
+
selected_model = MODEL_OPTIONS["medium"]
|
208 |
+
else:
|
209 |
+
selected_model = MODEL_OPTIONS["small"]
|
210 |
+
|
211 |
+
logger.info(f"π― Auto-selected model: {selected_model} (Available memory: {memory_gb:.1f}GB)")
|
212 |
+
|
213 |
+
# Initialize loader
|
214 |
+
self.pretrained_loader = MambaWeightLoader(selected_model)
|
215 |
+
|
216 |
+
# Download and load
|
217 |
+
if self.pretrained_loader.download_and_load():
|
218 |
+
self.model = self.pretrained_loader.model
|
219 |
+
self.tokenizer = self.pretrained_loader.tokenizer
|
220 |
+
self.config = self.pretrained_loader.config
|
221 |
+
self.model_loaded = True
|
222 |
+
self.using_pretrained = True
|
223 |
+
|
224 |
+
logger.info("β
Pretrained model loaded successfully!")
|
225 |
+
return True
|
226 |
+
else:
|
227 |
+
logger.warning("β Pretrained model loading failed")
|
228 |
+
return False
|
229 |
+
|
230 |
+
except Exception as e:
|
231 |
+
logger.error(f"Pretrained model loading error: {e}")
|
232 |
+
return False
|
233 |
+
|
234 |
+
def _load_custom_swarm_model(self):
|
235 |
+
"""Try to load custom swarm model implementation"""
|
236 |
try:
|
237 |
+
logger.info("Attempting to load custom Mamba Swarm model...")
|
238 |
+
|
239 |
+
# Try multiple import paths for the custom model
|
240 |
model_class = None
|
241 |
|
|
|
242 |
try:
|
243 |
from modeling_mamba_swarm import MambaSwarmForCausalLM
|
244 |
model_class = MambaSwarmForCausalLM
|
245 |
+
logger.info("Found MambaSwarmForCausalLM")
|
246 |
except ImportError:
|
247 |
try:
|
248 |
+
from core.mamba_swarm_integration import MambaEncoderSwarmModel
|
249 |
+
model_class = MambaEncoderSwarmModel
|
250 |
+
logger.info("Found MambaEncoderSwarmModel")
|
251 |
except ImportError:
|
252 |
try:
|
253 |
+
from system.mambaSwarm import UnifiedMambaSwarm
|
254 |
+
# Use the unified swarm in native mode
|
255 |
+
swarm = UnifiedMambaSwarm(use_pretrained=False)
|
256 |
+
if hasattr(swarm, 'native_swarm_model') and swarm.native_swarm_model:
|
257 |
+
self.model = swarm.native_swarm_model
|
258 |
+
self.model_loaded = True
|
259 |
+
logger.info("Loaded native swarm model from UnifiedMambaSwarm")
|
260 |
+
return True
|
261 |
+
else:
|
262 |
+
raise ImportError("No native swarm model available")
|
263 |
except ImportError:
|
264 |
+
logger.warning("No custom swarm model found")
|
265 |
+
return False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
266 |
|
267 |
if model_class is None:
|
268 |
+
return False
|
269 |
|
270 |
+
# Create configuration for custom model
|
271 |
try:
|
272 |
+
from modeling_mamba_swarm import MambaSwarmConfig
|
273 |
+
self.config = MambaSwarmConfig(
|
274 |
+
num_encoders=8,
|
275 |
+
max_mamba_encoders=100,
|
276 |
+
d_model=768,
|
277 |
+
vocab_size=50257,
|
278 |
+
max_sequence_length=2048
|
279 |
+
)
|
280 |
+
except ImportError:
|
281 |
+
# Fallback config
|
282 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
283 |
from core.config import MambaConfig
|
284 |
self.config = MambaConfig()
|
|
|
285 |
self.config.num_encoders = 8
|
286 |
self.config.max_mamba_encoders = 100
|
287 |
+
except ImportError:
|
288 |
+
# Create minimal config
|
289 |
+
self.config = type('Config', (), {
|
290 |
+
'num_encoders': 8,
|
291 |
+
'max_mamba_encoders': 100,
|
292 |
+
'd_model': 768,
|
293 |
+
'vocab_size': 50257,
|
294 |
+
'max_sequence_length': 2048
|
295 |
+
})()
|
296 |
|
297 |
+
# Initialize custom model
|
298 |
+
if model_class.__name__ == 'MambaEncoderSwarmModel':
|
299 |
+
self.model = model_class(self.config, num_encoders=8)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
300 |
else:
|
301 |
+
self.model = model_class(self.config)
|
302 |
+
|
303 |
+
# Create tokenizer
|
304 |
+
from transformers import GPT2Tokenizer
|
305 |
+
self.tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
|
306 |
+
if self.tokenizer.pad_token is None:
|
307 |
+
self.tokenizer.pad_token = self.tokenizer.eos_token
|
|
|
|
|
|
|
|
|
|
|
308 |
|
309 |
self.model.to(self.device)
|
310 |
self.model.eval()
|
311 |
self.model_loaded = True
|
312 |
|
313 |
+
logger.info("β
Custom swarm model loaded successfully!")
|
314 |
+
return True
|
|
|
|
|
315 |
|
316 |
except Exception as e:
|
317 |
+
logger.error(f"Custom model loading error: {e}")
|
318 |
+
return False
|
319 |
|
320 |
def _initialize_fallback_mode(self):
|
321 |
"""Initialize fallback/simulation mode"""
|
|
|
350 |
self.eos_token = "[EOS]"
|
351 |
|
352 |
def encode(self, text, return_tensors=None):
|
|
|
353 |
tokens = text.split()
|
354 |
token_ids = [hash(token) % 1000 for token in tokens]
|
355 |
if return_tensors == "pt":
|
|
|
357 |
return token_ids
|
358 |
|
359 |
def decode(self, token_ids, skip_special_tokens=True):
|
|
|
360 |
return f"Generated response for {len(token_ids)} tokens"
|
361 |
|
362 |
self.tokenizer = MockTokenizer()
|
|
|
412 |
available_encoders = list(range(start, min(end + 1, 101)))
|
413 |
|
414 |
# Select encoders based on prompt complexity and domain
|
415 |
+
prompt_complexity = min(len(prompt.split()) / 10, 3.0)
|
416 |
optimal_count = min(max(int(num_encoders * (1 + prompt_complexity)), 3), 25)
|
417 |
|
418 |
if len(available_encoders) >= optimal_count:
|
|
|
435 |
'total_active': len(selected_encoders)
|
436 |
}
|
437 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
438 |
def generate_text(self, prompt: str, max_length: int = 100, temperature: float = 0.7,
|
439 |
top_p: float = 0.9, num_encoders: int = 5, show_routing: bool = True) -> Tuple[str, str]:
|
440 |
+
"""Generate text with comprehensive error handling and routing information"""
|
|
|
|
|
|
|
|
|
|
|
441 |
start_time = time.time()
|
442 |
|
443 |
# Update statistics
|
|
|
454 |
# Real model generation
|
455 |
response = self._generate_real(prompt, max_length, temperature, top_p, num_encoders)
|
456 |
else:
|
457 |
+
# Simulated generation
|
458 |
response = self._simulate_generation(prompt, routing_info, max_length)
|
459 |
|
460 |
# Calculate performance metrics
|
|
|
486 |
|
487 |
def _generate_real(self, prompt: str, max_length: int, temperature: float,
|
488 |
top_p: float, num_encoders: int) -> str:
|
489 |
+
"""Generate using real pretrained model"""
|
490 |
try:
|
491 |
# Encode input
|
492 |
inputs = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
|
493 |
|
494 |
+
# Adjust number of active encoders (if supported)
|
495 |
if hasattr(self.model, 'set_active_encoders'):
|
496 |
+
max_encoders = getattr(self.config, 'max_mamba_encoders', 100)
|
497 |
+
self.model.set_active_encoders(min(num_encoders, max_encoders))
|
498 |
|
499 |
# Generate with memory optimization
|
500 |
with torch.no_grad():
|
501 |
+
try:
|
502 |
+
outputs = self.model.generate(
|
503 |
+
inputs,
|
504 |
+
max_new_tokens=min(max_length, 512), # Limit for stability
|
505 |
+
temperature=temperature,
|
506 |
+
top_p=top_p,
|
507 |
+
do_sample=True,
|
508 |
+
pad_token_id=self.tokenizer.pad_token_id,
|
509 |
+
eos_token_id=self.tokenizer.eos_token_id,
|
510 |
+
use_cache=True,
|
511 |
+
attention_mask=torch.ones_like(inputs) # Ensure attention mask
|
512 |
+
)
|
513 |
+
except Exception as gen_error:
|
514 |
+
logger.warning(f"Generation with parameters failed: {gen_error}")
|
515 |
+
# Fallback to simpler generation
|
516 |
+
outputs = self.model.generate(
|
517 |
+
inputs,
|
518 |
+
max_new_tokens=min(max_length, 256),
|
519 |
+
do_sample=False, # Use greedy decoding as fallback
|
520 |
+
pad_token_id=self.tokenizer.pad_token_id,
|
521 |
+
eos_token_id=self.tokenizer.eos_token_id
|
522 |
+
)
|
523 |
|
524 |
# Decode output
|
525 |
generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
|
526 |
|
527 |
# Remove input prompt from output
|
528 |
+
if generated_text.startswith(prompt):
|
529 |
+
response = generated_text[len(prompt):].strip()
|
530 |
+
else:
|
531 |
+
response = generated_text.strip()
|
532 |
|
533 |
return response if response else "Generated response was empty."
|
534 |
|
535 |
except torch.cuda.OutOfMemoryError:
|
536 |
logger.error("CUDA out of memory during generation")
|
537 |
+
return "Error: GPU memory insufficient. Try reducing max_length or switching to CPU mode."
|
538 |
except Exception as e:
|
539 |
logger.error(f"Real generation error: {e}")
|
540 |
+
return f"Generation error: {str(e)}. Using pretrained model in fallback mode."
|
541 |
+
|
542 |
+
def _simulate_generation(self, prompt: str, routing_info: Dict, max_length: int) -> str:
|
543 |
+
"""Generate sophisticated simulated responses"""
|
544 |
+
domain = routing_info['detected_domain']
|
545 |
+
|
546 |
+
# Enhanced domain-specific responses
|
547 |
+
if domain == 'code':
|
548 |
+
return f"""Here's a comprehensive solution for your request:
|
549 |
+
|
550 |
+
```python
|
551 |
+
def solution(input_data):
|
552 |
+
\"\"\"
|
553 |
+
Optimized implementation based on your requirements
|
554 |
+
\"\"\"
|
555 |
+
try:
|
556 |
+
# Input validation
|
557 |
+
if not input_data:
|
558 |
+
raise ValueError("Input cannot be empty")
|
559 |
+
|
560 |
+
# Process the data
|
561 |
+
result = process_input(input_data)
|
562 |
+
|
563 |
+
return result
|
564 |
+
except Exception as e:
|
565 |
+
print(f"Error: {{e}}")
|
566 |
+
return None
|
567 |
+
|
568 |
+
def process_input(data):
|
569 |
+
# Implementation here
|
570 |
+
return processed_data
|
571 |
+
```
|
572 |
+
|
573 |
+
This solution includes error handling, input validation, and follows best practices for production code."""
|
574 |
+
|
575 |
+
elif domain == 'medical':
|
576 |
+
return f"""Based on current medical knowledge regarding your query:
|
577 |
+
|
578 |
+
**Overview:**
|
579 |
+
This topic involves several important medical considerations that should be evaluated by healthcare professionals.
|
580 |
+
|
581 |
+
**Key Points:**
|
582 |
+
β’ Symptoms and presentation can vary significantly between individuals
|
583 |
+
β’ Early detection and proper diagnosis are crucial
|
584 |
+
β’ Treatment approaches should be personalized
|
585 |
+
β’ Regular monitoring may be recommended
|
586 |
+
|
587 |
+
**Important Note:** This information is for educational purposes only. Please consult with qualified healthcare professionals for personalized medical advice, diagnosis, and treatment recommendations."""
|
588 |
+
|
589 |
+
else:
|
590 |
+
return f"""**Response to: "{prompt[:50]}..."**
|
591 |
+
|
592 |
+
Based on analysis from {routing_info['total_active']} specialized encoders in the {domain} domain:
|
593 |
+
|
594 |
+
This is a comprehensive response that addresses your query with relevant information and insights. The analysis considers multiple perspectives and provides a balanced view of the topic.
|
595 |
+
|
596 |
+
**Key insights:**
|
597 |
+
β’ The topic involves several interconnected factors
|
598 |
+
β’ Current understanding is based on established principles
|
599 |
+
β’ Practical applications may vary depending on context
|
600 |
+
β’ Further exploration could yield additional insights
|
601 |
+
|
602 |
+
**Domain expertise applied:** {domain.title()} specialization with {routing_info['domain_confidence']:.1%} confidence."""
|
603 |
|
604 |
def _create_routing_display(self, routing_info: Dict, generation_time: float,
|
605 |
estimated_tokens: int) -> str:
|
606 |
"""Create rich routing information display"""
|
607 |
+
model_type = "Real Pretrained Model" if (self.model_loaded and not self.fallback_mode and self.using_pretrained) else "Custom Swarm Model" if (self.model_loaded and not self.fallback_mode) else "Simulation Mode"
|
608 |
+
model_name = getattr(self.pretrained_loader, 'model_name', 'Custom/Simulation') if self.pretrained_loader else 'Custom/Simulation'
|
609 |
+
|
610 |
return f"""
|
611 |
## π§ Intelligent Routing Analysis
|
612 |
|
|
|
615 |
- **Confidence**: {routing_info['domain_confidence']:.1%}
|
616 |
- **Specialization Level**: {'High' if routing_info['domain_confidence'] > 0.7 else 'Medium' if routing_info['domain_confidence'] > 0.4 else 'General'}
|
617 |
|
618 |
+
**β‘ Model Information:**
|
619 |
+
- **Model Type**: {model_type}
|
620 |
+
- **Base Model**: {model_name}
|
621 |
+
- **Active Encoders**: {routing_info['total_active']}/{getattr(self.config, 'max_mamba_encoders', 100)}
|
622 |
+
- **Device**: {self.device}
|
623 |
|
624 |
**π’ Selected Encoder IDs:**
|
625 |
{', '.join(map(str, routing_info['selected_encoders'][:15]))}{'...' if len(routing_info['selected_encoders']) > 15 else ''}
|
|
|
628 |
- **Generation Time**: {generation_time:.2f}s
|
629 |
- **Estimated Tokens**: {estimated_tokens}
|
630 |
- **Tokens/Second**: {estimated_tokens/generation_time:.1f}
|
631 |
+
- **Success Rate**: {(self.stats['successful_generations'] / max(self.stats['total_requests'], 1) * 100):.1f}%
|
632 |
|
633 |
**ποΈ Confidence Scores (Top 5):**
|
634 |
{', '.join([f'{score:.3f}' for score in routing_info['confidence_scores'][:5]])}{'...' if len(routing_info['confidence_scores']) > 5 else ''}
|
635 |
|
636 |
**π‘ Optimization Notes:**
|
637 |
- Encoder selection optimized for domain: {routing_info['detected_domain']}
|
638 |
+
- {'Pretrained weights from HuggingFace' if self.using_pretrained else 'Custom swarm implementation' if self.model_loaded and not self.fallback_mode else 'Simulation mode active'}
|
639 |
- Dynamic load balancing across {routing_info['total_active']} active encoders
|
|
|
640 |
"""
|
641 |
|
642 |
def get_model_info(self) -> str:
|
|
|
650 |
if torch.cuda.is_available():
|
651 |
gpu_info = f"{torch.cuda.get_device_name(0)} ({torch.cuda.get_device_properties(0).total_memory // 1024**3}GB)"
|
652 |
|
653 |
+
# Get pretrained model info if available
|
654 |
+
pretrained_info = ""
|
655 |
+
if self.pretrained_loader:
|
656 |
+
model_info = self.pretrained_loader.get_model_info()
|
657 |
+
if model_info and 'error' not in model_info:
|
658 |
+
pretrained_info = f"""
|
659 |
+
**π€ Pretrained Model Details:**
|
660 |
+
- **Model Name**: {model_info['name']}
|
661 |
+
- **Parameters**: {model_info['parameters']} ({model_info['parameters_millions']})
|
662 |
+
- **Vocabulary Size**: {model_info['vocab_size']:,}
|
663 |
+
- **Hidden Size**: {model_info['hidden_size']}
|
664 |
+
- **Model Device**: {model_info['device']}
|
665 |
+
- **Data Type**: {model_info['dtype']}
|
666 |
+
"""
|
667 |
+
|
668 |
+
status_emoji = "β
" if self.model_loaded and not self.fallback_mode else "β οΈ"
|
669 |
+
status_text = f"Loaded {'with Pretrained Weights' if self.using_pretrained else 'with Custom Swarm'}" if self.model_loaded and not self.fallback_mode else "Simulation Mode"
|
670 |
+
|
671 |
return f"""
|
672 |
**π€ Mamba Encoder Swarm Model Information**
|
673 |
|
674 |
**Model Configuration:**
|
675 |
+
- **Status**: {status_emoji} {status_text}
|
676 |
- **Active Encoders**: {getattr(self.model, 'num_active_encoders', 'N/A')}
|
677 |
+
- **Max Encoders**: {getattr(self.config, 'max_mamba_encoders', 100)}
|
678 |
+
- **Model Dimension**: {getattr(self.config, 'd_model', getattr(self.config, 'hidden_size', 768))}
|
679 |
+
- **Vocabulary Size**: {getattr(self.config, 'vocab_size', 50257):,}
|
680 |
- **Max Sequence Length**: {getattr(self.config, 'max_sequence_length', 'N/A')}
|
681 |
+
{pretrained_info}
|
682 |
**System Information:**
|
683 |
- **Device**: {self.device} {f'({gpu_info})' if gpu_info != 'N/A' else ''}
|
684 |
- **RAM Usage**: {memory_info.percent:.1f}% ({memory_info.used // 1024**3}GB / {memory_info.total // 1024**3}GB)
|
685 |
+
- **PyTorch Version**: {torch.__version__}
|
686 |
|
687 |
**Performance Statistics:**
|
688 |
- **Total Requests**: {self.stats['total_requests']}
|
|
|
692 |
- **Avg Generation Time**: {self.stats['avg_generation_time']:.2f}s
|
693 |
- **Total Tokens Generated**: {self.stats['total_tokens_generated']:,}
|
694 |
|
695 |
+
**Mode**: {'π’ Pretrained Model Active' if self.using_pretrained else 'π΅ Custom Swarm Active' if self.model_loaded and not self.fallback_mode else 'π‘ Simulation Mode'}
|
696 |
"""
|
697 |
|
698 |
def get_system_status(self) -> Dict[str, Any]:
|
699 |
"""Get system status for monitoring"""
|
700 |
return {
|
701 |
'model_loaded': self.model_loaded,
|
702 |
+
'using_pretrained': self.using_pretrained,
|
703 |
'fallback_mode': self.fallback_mode,
|
704 |
'device': str(self.device),
|
705 |
'stats': self.stats.copy(),
|
706 |
'timestamp': datetime.now().isoformat()
|
707 |
}
|
708 |
+
|
709 |
+
def switch_model(self, model_size: str = "auto") -> str:
|
710 |
+
"""Switch between different pretrained model sizes"""
|
711 |
+
if not self.using_pretrained:
|
712 |
+
return "β Model switching only available when using pretrained models"
|
713 |
+
|
714 |
+
try:
|
715 |
+
MODEL_OPTIONS = {
|
716 |
+
"small": "state-spaces/mamba-130m",
|
717 |
+
"medium": "state-spaces/mamba-790m",
|
718 |
+
"large": "state-spaces/mamba-1.4b",
|
719 |
+
"xl": "state-spaces/mamba-2.8b"
|
720 |
+
}
|
721 |
+
|
722 |
+
if model_size == "auto":
|
723 |
+
# Auto-select based on memory
|
724 |
+
memory_gb = psutil.virtual_memory().total / (1024**3)
|
725 |
+
if memory_gb >= 32 and torch.cuda.is_available():
|
726 |
+
model_size = "xl"
|
727 |
+
elif memory_gb >= 16 and torch.cuda.is_available():
|
728 |
+
model_size = "large"
|
729 |
+
elif memory_gb >= 8:
|
730 |
+
model_size = "medium"
|
731 |
+
else:
|
732 |
+
model_size = "small"
|
733 |
+
|
734 |
+
if model_size not in MODEL_OPTIONS:
|
735 |
+
return f"β Invalid model size. Choose from: {list(MODEL_OPTIONS.keys())}"
|
736 |
+
|
737 |
+
selected_model = MODEL_OPTIONS[model_size]
|
738 |
+
|
739 |
+
# Check if already using this model
|
740 |
+
if self.pretrained_loader and self.pretrained_loader.model_name == selected_model:
|
741 |
+
return f"β
Already using {selected_model}"
|
742 |
+
|
743 |
+
logger.info(f"π Switching to model: {selected_model}")
|
744 |
+
|
745 |
+
# Clear current model
|
746 |
+
if self.model:
|
747 |
+
del self.model
|
748 |
+
torch.cuda.empty_cache() if torch.cuda.is_available() else None
|
749 |
+
|
750 |
+
# Load new model
|
751 |
+
self.pretrained_loader = MambaWeightLoader(selected_model)
|
752 |
+
|
753 |
+
if self.pretrained_loader.download_and_load():
|
754 |
+
self.model = self.pretrained_loader.model
|
755 |
+
self.tokenizer = self.pretrained_loader.tokenizer
|
756 |
+
self.config = self.pretrained_loader.config
|
757 |
+
|
758 |
+
logger.info(f"β
Successfully switched to {selected_model}")
|
759 |
+
return f"β
Successfully switched to {selected_model}"
|
760 |
+
else:
|
761 |
+
logger.error(f"β Failed to switch to {selected_model}")
|
762 |
+
return f"β Failed to switch to {selected_model}"
|
763 |
+
|
764 |
+
except Exception as e:
|
765 |
+
logger.error(f"Error switching model: {e}")
|
766 |
+
return f"β Error switching model: {str(e)}"
|
767 |
|
768 |
def create_production_demo() -> gr.Blocks:
|
769 |
+
"""Create production-ready Gradio interface with pretrained model support"""
|
770 |
|
771 |
+
# Initialize demo with pretrained model capability
|
772 |
try:
|
773 |
demo_instance = MambaSwarmDemo(model_path="./", fallback_mode=False)
|
774 |
except Exception as e:
|
|
|
784 |
def refresh_model_info():
|
785 |
return demo_instance.get_model_info()
|
786 |
|
787 |
+
def switch_model_size(model_size):
|
788 |
+
result = demo_instance.switch_model(model_size)
|
789 |
+
return result, demo_instance.get_model_info()
|
790 |
+
|
791 |
# Create interface
|
792 |
with gr.Blocks(
|
793 |
+
title="Mamba Encoder Swarm - Production Demo with Pretrained Weights",
|
794 |
theme=gr.themes.Soft(),
|
795 |
css="""
|
796 |
.gradio-container {
|
|
|
809 |
padding: 15px;
|
810 |
margin: 10px 0;
|
811 |
}
|
812 |
+
.status-indicator {
|
813 |
+
background-color: #d4edda;
|
814 |
+
border: 1px solid #c3e6cb;
|
815 |
+
border-radius: 8px;
|
816 |
+
padding: 10px;
|
817 |
+
margin: 10px 0;
|
818 |
+
}
|
819 |
"""
|
820 |
) as demo:
|
821 |
|
|
|
823 |
gr.Markdown("""
|
824 |
# π Mamba Encoder Swarm - Production Demo
|
825 |
|
826 |
+
**Advanced Language Model with Pretrained Weights & Dynamic Routing**
|
827 |
|
828 |
+
Now featuring **automatic pretrained weight loading** from HuggingFace's state-spaces Mamba models,
|
829 |
+
with intelligent domain-aware routing across up to 100 specialized encoders.
|
830 |
""")
|
831 |
|
832 |
# Status indicator
|
833 |
with gr.Row():
|
834 |
+
with gr.Column(scale=3):
|
835 |
+
status_text = f"π’ Real Pretrained Model" if demo_instance.using_pretrained else f"π΅ Custom Swarm Model" if demo_instance.model_loaded and not demo_instance.fallback_mode else "π‘ Simulation Mode"
|
836 |
status_indicator = gr.Markdown(
|
837 |
+
f"**Status**: {status_text}",
|
838 |
+
elem_classes=["status-indicator"]
|
839 |
)
|
840 |
+
with gr.Column(scale=1):
|
841 |
+
if demo_instance.using_pretrained:
|
842 |
+
model_switch = gr.Dropdown(
|
843 |
+
choices=["auto", "small", "medium", "large", "xl"],
|
844 |
+
value="auto",
|
845 |
+
label="π Switch Model",
|
846 |
+
info="Change pretrained model size"
|
847 |
+
)
|
848 |
+
switch_btn = gr.Button("Switch Model", variant="secondary", size="sm")
|
849 |
|
850 |
with gr.Row():
|
851 |
# Left column - Input and controls
|
|
|
925 |
value=show_model_info(),
|
926 |
elem_classes=["model-info"]
|
927 |
)
|
928 |
+
with gr.Column(scale=1):
|
929 |
+
refresh_info_btn = gr.Button("π Refresh Info", size="sm")
|
930 |
+
if demo_instance.using_pretrained:
|
931 |
+
model_status = gr.Textbox(
|
932 |
+
label="Model Switch Status",
|
933 |
+
interactive=False,
|
934 |
+
lines=2
|
935 |
+
)
|
936 |
|
937 |
# Examples section
|
938 |
with gr.Accordion("π‘ Example Prompts", open=True):
|
|
|
945 |
["Analyze the legal implications of AI-generated content", 350, 0.7, 0.9, 15, True],
|
946 |
["Write a creative short story about a time-traveling scientist", 400, 0.9, 0.95, 12, True],
|
947 |
["Develop a marketing strategy for a sustainable fashion startup", 300, 0.8, 0.9, 10, True],
|
948 |
+
["How does quantum entanglement work and what are its applications?", 350, 0.6, 0.9, 15, True],
|
949 |
+
["Explain the economic impact of renewable energy adoption", 300, 0.7, 0.9, 12, True]
|
950 |
]
|
951 |
|
952 |
gr.Examples(
|
|
|
958 |
label="Click any example to load it"
|
959 |
)
|
960 |
|
961 |
+
# Advanced features section
|
962 |
+
with gr.Accordion("π¬ Advanced Features", open=False):
|
963 |
+
gr.Markdown("""
|
964 |
+
### π Pretrained Model Features
|
965 |
+
- **Automatic Model Selection**: Chooses optimal model size based on available memory
|
966 |
+
- **Dynamic Model Switching**: Switch between different Mamba model sizes
|
967 |
+
- **HuggingFace Integration**: Direct loading from state-spaces repository
|
968 |
+
- **Memory Optimization**: Efficient loading with half-precision and device mapping
|
969 |
+
|
970 |
+
### π§ Intelligent Routing System
|
971 |
+
- **Domain Detection**: Automatic classification of prompt domains
|
972 |
+
- **Specialized Encoders**: 100+ domain-specific encoder pools
|
973 |
+
- **Load Balancing**: Dynamic distribution across active encoders
|
974 |
+
- **Confidence Scoring**: Weighted aggregation based on encoder confidence
|
975 |
+
|
976 |
+
### π Model Sizes Available
|
977 |
+
- **Small (130M)**: ~500MB, good for basic tasks
|
978 |
+
- **Medium (790M)**: ~3GB, balanced performance
|
979 |
+
- **Large (1.4B)**: ~5GB, high-quality responses
|
980 |
+
- **XL (2.8B)**: ~10GB, best performance (requires 16GB+ RAM)
|
981 |
+
""")
|
982 |
+
|
983 |
# Event handlers
|
984 |
generate_btn.click(
|
985 |
fn=generate_response,
|
|
|
993 |
outputs=model_info_display
|
994 |
)
|
995 |
|
996 |
+
# Model switching event handler (only if using pretrained)
|
997 |
+
if demo_instance.using_pretrained:
|
998 |
+
switch_btn.click(
|
999 |
+
fn=switch_model_size,
|
1000 |
+
inputs=[model_switch],
|
1001 |
+
outputs=[model_status, model_info_display]
|
1002 |
+
)
|
1003 |
+
|
1004 |
+
# Auto-refresh status on page load
|
1005 |
+
demo.load(
|
1006 |
+
fn=lambda: (demo_instance.get_model_info(), f"**Status**: {'π’ Real Pretrained Model' if demo_instance.using_pretrained else 'π΅ Custom Swarm Model' if demo_instance.model_loaded and not demo_instance.fallback_mode else 'π‘ Simulation Mode'}"),
|
1007 |
+
outputs=[model_info_display, status_indicator]
|
1008 |
+
)
|
1009 |
+
|
1010 |
# Footer
|
1011 |
gr.Markdown("""
|
1012 |
---
|
1013 |
+
### ποΈ Enhanced Architecture Overview
|
1014 |
+
|
1015 |
+
**π€ Pretrained Integration**
|
1016 |
+
- Direct loading from HuggingFace state-spaces Mamba models
|
1017 |
+
- Automatic model size selection based on system resources
|
1018 |
+
- Seamless fallback to custom swarm implementation
|
1019 |
+
- Dynamic model switching without restart
|
1020 |
|
1021 |
**π§ Intelligent Routing System**
|
1022 |
- Domain detection based on prompt analysis
|
1023 |
- Dynamic encoder selection optimized for content type
|
1024 |
- Load balancing across specialized encoder pools
|
1025 |
+
- Confidence-weighted response aggregation
|
1026 |
|
1027 |
**π§ Production Features**
|
1028 |
- Comprehensive error handling and fallback modes
|
|
|
1034 |
- **Medical & Healthcare** β’ **Legal & Regulatory** β’ **Code & Technical**
|
1035 |
- **Science & Research** β’ **Creative Writing** β’ **Business & Finance**
|
1036 |
|
1037 |
+
Built with β€οΈ using Gradio, PyTorch, HuggingFace Transformers, and the Mamba architecture
|
1038 |
""")
|
1039 |
|
1040 |
return demo
|
|
|
1091 |
demo.launch(share=False, debug=False)
|
1092 |
except Exception as e2:
|
1093 |
logger.error(f"Minimal launch also failed: {e2}")
|
1094 |
+
print(f"β All launch attempts failed. Error: {e2}")
|