Upload 4 files
- app.py +166 -54
- deploy_to_hf.sh +211 -0
- modeling_mamba_swarm.py +235 -0
- upload_to_hf.py +3 -0
app.py
CHANGED
@@ -89,30 +89,118 @@ class MambaSwarmDemo:
    def _load_real_model(self):
        """Load the actual Mamba Swarm model"""
        try:
+            # Try multiple import paths for the model
+            model_class = None
+
+            # Try importing from different locations
+            try:
+                from modeling_mamba_swarm import MambaSwarmForCausalLM
+                model_class = MambaSwarmForCausalLM
+                logger.info("Loaded MambaSwarmForCausalLM from modeling_mamba_swarm")
+            except ImportError:
+                try:
+                    from upload_to_hf import MambaSwarmForCausalLM
+                    model_class = MambaSwarmForCausalLM
+                    logger.info("Loaded MambaSwarmForCausalLM from upload_to_hf")
+                except ImportError:
+                    try:
+                        from core.mamba_swarm_integration import MambaEncoderSwarmModel
+                        model_class = MambaEncoderSwarmModel
+                        logger.info("Loaded MambaEncoderSwarmModel from core.mamba_swarm_integration")
+                    except ImportError:
+                        try:
+                            from system.mambaSwarm import UnifiedMambaSwarm
+                            # Use the unified swarm in native mode
+                            swarm = UnifiedMambaSwarm(use_pretrained=False)
+                            if hasattr(swarm, 'native_swarm_model') and swarm.native_swarm_model:
+                                self.model = swarm.native_swarm_model
+                                self.model_loaded = True
+                                logger.info("Loaded native swarm model from UnifiedMambaSwarm")
+                                return
+                            else:
+                                raise ImportError("No native swarm model available")
+                        except ImportError as e:
+                            logger.error(f"All model imports failed: {e}")
+                            raise ImportError("No compatible Mamba Swarm model found")
+
+            if model_class is None:
+                raise ImportError("No model class available")

            # Load configuration
+            try:
+                self.config = AutoConfig.from_pretrained(self.model_path, trust_remote_code=True)
+                logger.info(f"Loaded config: {self.config.__class__.__name__}")
+            except Exception as e:
+                logger.warning(f"Could not load config from {self.model_path}: {e}")
+                # Create a default config using our MambaSwarmConfig
+                try:
+                    from modeling_mamba_swarm import MambaSwarmConfig
+                    self.config = MambaSwarmConfig(
+                        num_encoders=8,
+                        max_mamba_encoders=100,
+                        d_model=768,
+                        vocab_size=50257,
+                        max_sequence_length=2048
+                    )
+                    logger.info("Using default MambaSwarmConfig")
+                except ImportError:
+                    # Final fallback to basic config
+                    from core.config import MambaConfig
+                    self.config = MambaConfig()
+                    # Add swarm-specific attributes
+                    self.config.num_encoders = 8
+                    self.config.max_mamba_encoders = 100
+                    self.config.max_sequence_length = 2048
+                    logger.info("Using default MambaConfig with swarm attributes")

            # Load tokenizer
+            try:
+                self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
+                if self.tokenizer.pad_token is None:
+                    self.tokenizer.pad_token = self.tokenizer.eos_token
+                logger.info("Tokenizer loaded successfully")
+            except Exception as e:
+                logger.warning(f"Could not load tokenizer: {e}")
+                # Use a simple fallback tokenizer
+                from transformers import GPT2Tokenizer
+                self.tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
+                if self.tokenizer.pad_token is None:
+                    self.tokenizer.pad_token = self.tokenizer.eos_token
+                logger.info("Using fallback GPT2 tokenizer")

            # Load model with memory optimization
            dtype = torch.float16 if self.device.type == "cuda" else torch.float32

+            if model_class == MambaEncoderSwarmModel:
+                # Native integration model - create with MambaConfig
+                from core.config import MambaConfig
+                if not hasattr(self, 'config') or not isinstance(self.config, MambaConfig):
+                    mamba_config = MambaConfig(
+                        d_model=getattr(self.config, 'd_model', 768),
+                        vocab_size=getattr(self.config, 'vocab_size', 50257),
+                        n_layers=8,
+                        d_state=16,
+                        d_conv=4,
+                        bias=False
+                    )
+                    self.model = model_class(mamba_config, num_encoders=getattr(self.config, 'num_encoders', 8))
+                else:
+                    self.model = model_class(self.config, num_encoders=getattr(self.config, 'num_encoders', 8))
+            else:
+                # HuggingFace-style model or our new MambaSwarmForCausalLM
+                if hasattr(model_class, 'from_pretrained') and os.path.exists(self.model_path):
+                    self.model = model_class.from_pretrained(
+                        self.model_path,
+                        config=self.config,
+                        torch_dtype=dtype,
+                        trust_remote_code=True,
+                        low_cpu_mem_usage=True
+                    )
+                else:
+                    # Create with config only
+                    self.model = model_class(self.config)

+            self.model.to(self.device)
            self.model.eval()
            self.model_loaded = True

@@ -121,9 +209,6 @@
            logger.info(f"Model loaded successfully on {self.device}")
            logger.info(f"Model parameters: {num_params:,} ({num_params/1e6:.1f}M)")

-        except ImportError as e:
-            logger.error(f"MambaSwarmForCausalLM not available: {e}")
-            raise
        except Exception as e:
            logger.error(f"Real model loading failed: {e}")
            raise

@@ -133,12 +218,24 @@
        logger.info("Initializing fallback simulation mode")

        # Create mock config
+        try:
+            from modeling_mamba_swarm import MambaSwarmConfig
+            self.config = MambaSwarmConfig(
+                num_encoders=8,
+                max_mamba_encoders=100,
+                d_model=768,
+                vocab_size=50257,
+                max_sequence_length=2048
+            )
+        except ImportError:
+            # Fallback mock config
+            self.config = type('MockConfig', (), {
+                'max_mamba_encoders': 100,
+                'num_encoders': 8,
+                'd_model': 768,
+                'vocab_size': 50257,
+                'max_sequence_length': 2048
+            })()

        # Create mock tokenizer
        class MockTokenizer:

@@ -774,36 +871,51 @@ if __name__ == "__main__":
    try:
        demo = create_production_demo()

-        # Launch with production settings
-            max_threads=10
+        # Launch with production settings - compatible with different Gradio versions
+        launch_kwargs = {
+            "server_name": "0.0.0.0",
+            "server_port": 7860,
+            "share": False,  # Set to True for public sharing
+            "debug": False,
+            "show_error": True,
+            "quiet": False,
+        }
+
+        # Add optional parameters if supported
+        try:
+            # Test if these parameters are supported in this Gradio version
+            import gradio as gr
+            import inspect
+            launch_signature = inspect.signature(gr.Blocks.launch)
+
+            # Add parameters if supported
+            if 'favicon_path' in launch_signature.parameters:
+                launch_kwargs['favicon_path'] = None
+            if 'ssl_verify' in launch_signature.parameters:
+                launch_kwargs['ssl_verify'] = False
+            if 'show_tips' in launch_signature.parameters:
+                launch_kwargs['show_tips'] = True
+            if 'enable_queue' in launch_signature.parameters:
+                launch_kwargs['enable_queue'] = True
+            if 'max_threads' in launch_signature.parameters:
+                launch_kwargs['max_threads'] = 10
+
+        except Exception as e:
+            logger.warning(f"Could not detect Gradio parameters: {e}")
+
+        # Launch with detected parameters
+        logger.info(f"Launching with parameters: {list(launch_kwargs.keys())}")
+        demo.launch(**launch_kwargs)
+
    except Exception as e:
+        logger.error(f"Failed to launch demo: {e}")
        print(f"❌ Demo launch failed: {e}")
        print("Please check the logs for more details.")
+
+        # Try minimal launch as last resort
+        try:
+            logger.info("Attempting minimal launch...")
+            demo.launch(share=False, debug=False)
+        except Exception as e2:
+            logger.error(f"Minimal launch also failed: {e2}")
+            print(f"❌ All launch attempts failed. Error: {e2}")
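The nested try/except chain in the first hunk walks a fixed priority list of import paths. The same idea can be written as a loop over candidate (module, class) pairs; this is a minimal sketch using the names from the diff above, not part of app.py, and it leaves out the UnifiedMambaSwarm fallback, which needs the extra instantiation logic shown in the hunk:

```python
import importlib
import logging

logger = logging.getLogger(__name__)

# Candidate (module, class) pairs in priority order, mirroring the diff above.
MODEL_CANDIDATES = [
    ("modeling_mamba_swarm", "MambaSwarmForCausalLM"),
    ("upload_to_hf", "MambaSwarmForCausalLM"),
    ("core.mamba_swarm_integration", "MambaEncoderSwarmModel"),
]

def resolve_model_class():
    """Return the first importable model class, or None if nothing is available."""
    for module_name, class_name in MODEL_CANDIDATES:
        try:
            module = importlib.import_module(module_name)
            model_class = getattr(module, class_name)
            logger.info("Loaded %s from %s", class_name, module_name)
            return model_class
        except (ImportError, AttributeError) as e:
            logger.debug("Import of %s.%s failed: %s", module_name, class_name, e)
    return None
```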
deploy_to_hf.sh
ADDED
@@ -0,0 +1,211 @@
#!/bin/bash
# deploy_to_hf.sh - Complete deployment script

echo "🚀 Deploying Mamba Swarm to HuggingFace..."

# Set your HuggingFace username
HF_USERNAME="your-username"  # Replace with your actual username

# Step 1: Create repositories on HuggingFace
echo "📦 Creating repositories..."
huggingface-cli repo create mamba-swarm-model --type model
huggingface-cli repo create mamba-swarm-weights --type model
huggingface-cli repo create mamba-swarm-demo --type space --space_sdk gradio

# Step 2: Clone repositories locally
echo "📁 Cloning repositories..."
mkdir -p hf_repos
cd hf_repos

git clone https://huggingface.co/$HF_USERNAME/mamba-swarm-model
git clone https://huggingface.co/$HF_USERNAME/mamba-swarm-weights
git clone https://huggingface.co/$HF_USERNAME/mamba-swarm-demo

# Step 3: Prepare model repository
echo "🔧 Preparing model code..."
cd mamba-swarm-model

# Copy your mamba_swarm code
cp -r ../../mamba_swarm .

# Create README.md
cat > README.md << 'EOF'
---
license: apache-2.0
language:
- en
pipeline_tag: text-generation
tags:
- mamba
- swarm
- routing
- language-model
library_name: transformers
---

# Mamba Swarm: Dynamic Routing Language Model

A novel architecture combining 100 specialized Mamba encoders with dynamic routing and aggregation for efficient language modeling.

## Quick Start

```python
from transformers import AutoModel, AutoTokenizer

# Load model and tokenizer
model = AutoModel.from_pretrained("$HF_USERNAME/mamba-swarm-model")
tokenizer = AutoTokenizer.from_pretrained("$HF_USERNAME/mamba-swarm-model")

# Generate text
input_text = "Explain quantum computing"
inputs = tokenizer(input_text, return_tensors="pt")
outputs = model.generate(**inputs, max_length=100)
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(response)
```

## Architecture

- **100 Mamba Encoders**: Domain-specialized experts
- **Dynamic Router**: Content-aware encoder selection
- **Aggregation Layer**: Intelligent output combination
- **Mamba Decoder**: Coherent response generation

## Demo

Try the interactive demo: [Mamba Swarm Demo](https://huggingface.co/spaces/$HF_USERNAME/mamba-swarm-demo)
EOF

# The heredoc above is quoted, so $HF_USERNAME stays literal; substitute it here
sed -i "s/\$HF_USERNAME/$HF_USERNAME/g" README.md

# Create requirements.txt
cat > requirements.txt << 'EOF'
torch>=2.0.0
transformers>=4.35.0
mamba-ssm>=1.2.0
causal-conv1d>=1.2.0
numpy>=1.21.0
scipy>=1.7.0
triton>=2.0.0
einops>=0.6.1
packaging>=20.0
accelerate>=0.20.0
EOF

# Create config.json
cat > config.json << 'EOF'
{
  "model_type": "mamba_swarm",
  "architectures": ["MambaSwarmForCausalLM"],
  "num_encoders": 100,
  "encoder_config": {
    "d_model": 768,
    "n_layer": 24,
    "vocab_size": 50280,
    "ssm_cfg": {},
    "rms_norm": true,
    "residual_in_fp32": true,
    "fused_add_norm": true
  },
  "router_config": {
    "top_k": 10,
    "routing_strategy": "content_based"
  },
  "aggregator_config": {
    "method": "weighted_sum",
    "attention_heads": 8
  },
  "torch_dtype": "float16",
  "use_cache": true
}
EOF

# Commit and push model code
git add .
git commit -m "Initial upload: Mamba Swarm model code"
git push

echo "✅ Model code uploaded!"

# Step 4: Prepare Gradio app
echo "🎨 Preparing Gradio demo..."
cd ../mamba-swarm-demo

# Copy the app.py file we created
cp ../../gradio_app.py app.py

# Update the model name in app.py
sed -i "s/your-username/$HF_USERNAME/g" app.py

# Create requirements.txt for the Space
cat > requirements.txt << 'EOF'
gradio>=4.0.0
torch>=2.0.0
transformers>=4.35.0
numpy>=1.21.0
mamba-ssm>=1.2.0
causal-conv1d>=1.2.0
EOF

# Create README.md for the Space
cat > README.md << 'EOF'
---
title: Mamba Swarm Demo
emoji: 🐍
colorFrom: green
colorTo: blue
sdk: gradio
sdk_version: 4.8.0
app_file: app.py
pinned: false
license: apache-2.0
---

# Mamba Swarm Interactive Demo

Experience the power of 100 specialized Mamba encoders with intelligent routing!

This demo showcases how our Mamba Swarm model dynamically selects the most relevant encoders for different types of queries, providing specialized responses across various domains.

## Features

- **Dynamic Routing**: Watch as the model selects optimal encoders
- **Domain Specialization**: See how different domains are handled
- **Interactive Interface**: Experiment with different parameters
- **Real-time Visualization**: View routing decisions and confidence scores

## Architecture

The Mamba Swarm consists of:
- 100 specialized Mamba encoders
- Intelligent content-based routing
- Advanced aggregation mechanisms
- Optimized inference pipeline

Try it out with different types of questions to see the routing in action!
EOF

# Commit and push Gradio app
git add .
git commit -m "Initial upload: Mamba Swarm Gradio demo"
git push

echo "✅ Gradio demo uploaded!"

# Step 5: Instructions for weights (when available)
echo "📋 Next steps for model weights:"
echo ""
echo "When you have trained model weights, upload them with:"
echo "cd hf_repos/mamba-swarm-weights"
echo "# Copy your checkpoint files here"
echo "git add ."
echo "git commit -m 'Upload trained model weights'"
echo "git push"
echo ""
echo "🎉 Deployment complete!"
echo ""
echo "Your repositories:"
echo "- Model: https://huggingface.co/$HF_USERNAME/mamba-swarm-model"
echo "- Weights: https://huggingface.co/$HF_USERNAME/mamba-swarm-weights"
echo "- Demo: https://huggingface.co/$HF_USERNAME/mamba-swarm-demo"
echo ""
echo "The Gradio demo will be available at:"
echo "https://huggingface.co/spaces/$HF_USERNAME/mamba-swarm-demo"
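deploy_to_hf.sh drives everything through huggingface-cli and git. The same repositories can be created and populated from Python with the huggingface_hub API that upload_to_hf.py already imports; a minimal sketch, assuming you have run `huggingface-cli login` and reusing the repository names and hf_repos/ layout from the script (this snippet is not part of the commit):

```python
from huggingface_hub import HfApi

api = HfApi()
username = "your-username"  # replace, as in the shell script

# Step 1 equivalent: create the three repositories (no error if they already exist)
api.create_repo(f"{username}/mamba-swarm-model", repo_type="model", exist_ok=True)
api.create_repo(f"{username}/mamba-swarm-weights", repo_type="model", exist_ok=True)
api.create_repo(f"{username}/mamba-swarm-demo", repo_type="space", space_sdk="gradio", exist_ok=True)

# Steps 3-4 equivalent: upload the prepared folders instead of git add/commit/push
api.upload_folder(folder_path="hf_repos/mamba-swarm-model",
                  repo_id=f"{username}/mamba-swarm-model", repo_type="model")
api.upload_folder(folder_path="hf_repos/mamba-swarm-demo",
                  repo_id=f"{username}/mamba-swarm-demo", repo_type="space")
```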
modeling_mamba_swarm.py
ADDED
@@ -0,0 +1,235 @@
# modeling_mamba_swarm.py - HuggingFace integration for Mamba Swarm

from transformers import PreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import CausalLMOutputWithPast
import torch
import torch.nn as nn
from typing import Optional, Tuple, Union
import logging

logger = logging.getLogger(__name__)

class MambaSwarmConfig(PretrainedConfig):
    """Configuration class for MambaSwarm model"""
    model_type = "mamba_swarm"

    def __init__(
        self,
        num_encoders=100,
        max_mamba_encoders=100,
        d_model=768,
        vocab_size=50257,
        max_sequence_length=2048,
        encoder_config=None,
        router_config=None,
        aggregator_config=None,
        **kwargs
    ):
        self.num_encoders = num_encoders
        self.max_mamba_encoders = max_mamba_encoders
        self.d_model = d_model
        self.vocab_size = vocab_size
        self.max_sequence_length = max_sequence_length
        self.encoder_config = encoder_config or {}
        self.router_config = router_config or {}
        self.aggregator_config = aggregator_config or {}
        super().__init__(**kwargs)

class MambaSwarmForCausalLM(PreTrainedModel):
    """HuggingFace compatible Mamba Swarm model"""
    config_class = MambaSwarmConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # Initialize core components
        try:
            # Try to use the unified swarm engine
            from system.mambaSwarm import UnifiedMambaSwarm
            self.swarm_engine = UnifiedMambaSwarm(
                config=config,
                use_pretrained=False  # Use native implementation
            )
            self.num_active_encoders = getattr(self.swarm_engine, 'num_encoders', config.num_encoders)
            logger.info("Initialized with UnifiedMambaSwarm")

        except ImportError:
            try:
                # Fallback to native swarm integration
                from core.mamba_swarm_integration import MambaEncoderSwarmModel
                from core.config import MambaConfig

                # Convert config to MambaConfig
                mamba_config = MambaConfig(
                    d_model=config.d_model,
                    vocab_size=config.vocab_size,
                    n_layers=8,   # Default
                    d_state=16,   # Default
                    d_conv=4,     # Default
                    bias=False    # Default
                )

                self.swarm_engine = MambaEncoderSwarmModel(
                    mamba_config,
                    num_encoders=config.num_encoders
                )
                self.num_active_encoders = config.num_encoders
                logger.info("Initialized with MambaEncoderSwarmModel")

            except ImportError as e:
                logger.error(f"Could not import swarm components: {e}")
                # Create a minimal mock implementation
                self.swarm_engine = self._create_mock_engine(config)
                self.num_active_encoders = config.num_encoders
                logger.warning("Using mock swarm engine")

    def _create_mock_engine(self, config):
        """Create a mock engine for testing purposes"""
        class MockSwarmEngine:
            def __init__(self, config):
                self.config = config
                self.embedding = nn.Embedding(config.vocab_size, config.d_model)
                self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
                self.num_active_encoders = config.num_encoders

            def forward(self, input_ids, **kwargs):
                # Simple passthrough for testing
                embeddings = self.embedding(input_ids)
                logits = self.lm_head(embeddings)
                return type('MockOutput', (), {'logits': logits, 'past_key_values': None})()

            def generate(self, input_ids, max_length=100, **kwargs):
                # Simple generation for testing
                batch_size, seq_len = input_ids.shape
                new_tokens = torch.randint(0, self.config.vocab_size, (batch_size, max_length - seq_len))
                return torch.cat([input_ids, new_tokens], dim=1)

            def set_active_encoders(self, num):
                self.num_active_encoders = min(num, self.config.max_mamba_encoders)

        return MockSwarmEngine(config)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        **kwargs
    ) -> CausalLMOutputWithPast:
        """Forward pass through the swarm model"""

        if input_ids is None:
            raise ValueError("input_ids must be provided")

        # Get outputs from swarm engine
        if hasattr(self.swarm_engine, 'forward'):
            outputs = self.swarm_engine.forward(input_ids, **kwargs)
            logits = outputs.logits if hasattr(outputs, 'logits') else outputs
        else:
            # Fallback for engines without forward method
            try:
                logits = self.swarm_engine(input_ids)
            except Exception as e:
                logger.error(f"Forward pass failed: {e}")
                # Emergency fallback
                batch_size, seq_len = input_ids.shape
                logits = torch.randn(batch_size, seq_len, self.config.vocab_size)

        loss = None
        if labels is not None:
            # Calculate cross-entropy loss
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=None,  # Mamba doesn't use key-value cache
        )

    def generate(
        self,
        input_ids: torch.LongTensor,
        max_length: int = 100,
        temperature: float = 1.0,
        top_p: float = 0.9,
        do_sample: bool = True,
        **kwargs
    ) -> torch.LongTensor:
        """Generate text using the swarm model"""

        try:
            if hasattr(self.swarm_engine, 'generate'):
                return self.swarm_engine.generate(
                    input_ids=input_ids,
                    max_length=max_length,
                    temperature=temperature,
                    top_p=top_p,
                    do_sample=do_sample,
                    **kwargs
                )
            else:
                # Manual generation loop
                return self._manual_generate(input_ids, max_length, temperature, top_p, do_sample)

        except Exception as e:
            logger.error(f"Generation failed: {e}")
            # Return input with some random tokens as fallback
            batch_size, seq_len = input_ids.shape
            new_tokens = torch.randint(0, self.config.vocab_size, (batch_size, max_length - seq_len))
            return torch.cat([input_ids, new_tokens], dim=1)

    def _manual_generate(self, input_ids, max_length, temperature, top_p, do_sample):
        """Manual generation when swarm engine doesn't have generate method"""
        self.eval()

        with torch.no_grad():
            for _ in range(max_length - input_ids.size(1)):
                outputs = self.forward(input_ids)
                logits = outputs.logits[:, -1, :] / temperature

                if do_sample:
                    # Apply top-p filtering
                    if top_p < 1.0:
                        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                        cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
                        sorted_indices_to_remove = cumulative_probs > top_p
                        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                        sorted_indices_to_remove[..., 0] = 0
                        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                        logits[indices_to_remove] = float('-inf')

                    probs = torch.softmax(logits, dim=-1)
                    next_token = torch.multinomial(probs, num_samples=1)
                else:
                    next_token = torch.argmax(logits, dim=-1, keepdim=True)

                input_ids = torch.cat([input_ids, next_token], dim=1)

        return input_ids

    def set_active_encoders(self, num_encoders: int):
        """Set the number of active encoders"""
        if hasattr(self.swarm_engine, 'set_active_encoders'):
            self.swarm_engine.set_active_encoders(num_encoders)
            self.num_active_encoders = num_encoders
        else:
            self.num_active_encoders = min(num_encoders, self.config.max_mamba_encoders)

    @classmethod
    def from_pretrained(cls, model_name_or_path, *model_args, **kwargs):
        """Load model from pretrained weights"""
        try:
            return super().from_pretrained(model_name_or_path, *model_args, **kwargs)
        except Exception as e:
            logger.warning(f"Could not load pretrained model: {e}")
            # Create with default config if loading fails
            config = MambaSwarmConfig()
            return cls(config)

    def get_num_params(self):
        """Get total number of parameters"""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
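Because MambaSwarmConfig sets model_type = "mamba_swarm" and MambaSwarmForCausalLM declares it as config_class, the pair can be registered with the transformers Auto API so that AutoConfig and AutoModelForCausalLM resolve the custom type locally. A minimal sketch, assuming modeling_mamba_swarm.py is on the import path; the registration calls are not part of this commit:

```python
from transformers import AutoConfig, AutoModelForCausalLM

from modeling_mamba_swarm import MambaSwarmConfig, MambaSwarmForCausalLM

# Map the custom model_type to the classes above so Auto* can resolve "mamba_swarm".
AutoConfig.register("mamba_swarm", MambaSwarmConfig)
AutoModelForCausalLM.register(MambaSwarmConfig, MambaSwarmForCausalLM)

# Build an untrained model from a small default config.
config = MambaSwarmConfig(num_encoders=8)
model = AutoModelForCausalLM.from_config(config)
print(f"Active encoders: {model.num_active_encoders}")
```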
upload_to_hf.py
CHANGED
@@ -5,6 +5,9 @@ import shutil
from huggingface_hub import HfApi, upload_folder
import json

+# Import the actual model classes
+from modeling_mamba_swarm import MambaSwarmForCausalLM, MambaSwarmConfig
+
def prepare_model_repo():
    """Prepare model repository structure for HuggingFace"""
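A quick local smoke test ties the pieces together before any weights are uploaded: build MambaSwarmForCausalLM from a default config and run its generate loop with the GPT-2 tokenizer that app.py falls back to. This is a sketch assuming only the files in this commit plus torch and transformers; without trained weights the output is placeholder text:

```python
import torch
from transformers import GPT2Tokenizer

from modeling_mamba_swarm import MambaSwarmConfig, MambaSwarmForCausalLM

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

config = MambaSwarmConfig(num_encoders=8, vocab_size=len(tokenizer))
model = MambaSwarmForCausalLM(config)
model.eval()

inputs = tokenizer("Explain quantum computing", return_tensors="pt")
with torch.no_grad():
    output_ids = model.generate(inputs["input_ids"], max_length=32, do_sample=False)

# With the mock engine (no swarm modules installed) this prints random tokens.
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```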