# modeling_mamba_swarm.py - HuggingFace integration for Mamba Swarm

from transformers import PreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import CausalLMOutputWithPast
import torch
import torch.nn as nn
from typing import Optional, Tuple, Union
import logging

logger = logging.getLogger(__name__)

class MambaSwarmConfig(PretrainedConfig):
    """Configuration class for the MambaSwarm model"""

    model_type = "mamba_swarm"

    def __init__(
        self,
        num_encoders=100,
        max_mamba_encoders=100,
        d_model=768,
        vocab_size=50257,
        max_sequence_length=2048,
        encoder_config=None,
        router_config=None,
        aggregator_config=None,
        **kwargs
    ):
        self.num_encoders = num_encoders
        self.max_mamba_encoders = max_mamba_encoders
        self.d_model = d_model
        self.vocab_size = vocab_size
        self.max_sequence_length = max_sequence_length
        self.encoder_config = encoder_config or {}
        self.router_config = router_config or {}
        self.aggregator_config = aggregator_config or {}
        super().__init__(**kwargs)

class MambaSwarmForCausalLM(PreTrainedModel):
    """HuggingFace-compatible Mamba Swarm model"""

    config_class = MambaSwarmConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # Initialize core components
        try:
            # Try to use the unified swarm engine
            from system.mambaSwarm import UnifiedMambaSwarm
            self.swarm_engine = UnifiedMambaSwarm(
                config=config,
                use_pretrained=False  # Use native implementation
            )
            self.num_active_encoders = getattr(self.swarm_engine, 'num_encoders', config.num_encoders)
            logger.info("Initialized with UnifiedMambaSwarm")
        except ImportError:
            try:
                # Fallback to native swarm integration
                from core.mamba_swarm_integration import MambaEncoderSwarmModel
                from core.config import MambaConfig

                # Convert config to MambaConfig
                mamba_config = MambaConfig(
                    d_model=config.d_model,
                    vocab_size=config.vocab_size,
                    n_layers=8,  # Default
                    d_state=16,  # Default
                    d_conv=4,    # Default
                    bias=False   # Default
                )
                self.swarm_engine = MambaEncoderSwarmModel(
                    mamba_config,
                    num_encoders=config.num_encoders
                )
                self.num_active_encoders = config.num_encoders
                logger.info("Initialized with MambaEncoderSwarmModel")
            except ImportError as e:
                logger.error(f"Could not import swarm components: {e}")
                # Create a minimal mock implementation
                self.swarm_engine = self._create_mock_engine(config)
                self.num_active_encoders = config.num_encoders
                logger.warning("Using mock swarm engine")

    def _create_mock_engine(self, config):
        """Create a mock engine for testing purposes"""

        class MockSwarmEngine(nn.Module):
            def __init__(self, config):
                # Subclass nn.Module so the mock's parameters are registered on
                # the parent model and follow .to()/.parameters() as expected.
                super().__init__()
                self.config = config
                self.embedding = nn.Embedding(config.vocab_size, config.d_model)
                self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
                self.num_active_encoders = config.num_encoders

            def forward(self, input_ids, **kwargs):
                # Simple passthrough for testing
                embeddings = self.embedding(input_ids)
                logits = self.lm_head(embeddings)
                return type('MockOutput', (), {'logits': logits, 'past_key_values': None})()

            def generate(self, input_ids, max_length=100, **kwargs):
                # Simple random generation for testing, on the same device as the input
                batch_size, seq_len = input_ids.shape
                new_tokens = torch.randint(
                    0, self.config.vocab_size,
                    (batch_size, max_length - seq_len),
                    device=input_ids.device
                )
                return torch.cat([input_ids, new_tokens], dim=1)

            def set_active_encoders(self, num):
                self.num_active_encoders = min(num, self.config.max_mamba_encoders)

        return MockSwarmEngine(config)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        **kwargs
    ) -> CausalLMOutputWithPast:
        """Forward pass through the swarm model"""
        if input_ids is None:
            raise ValueError("input_ids must be provided")

        # Get outputs from the swarm engine
        if hasattr(self.swarm_engine, 'forward'):
            outputs = self.swarm_engine.forward(input_ids, **kwargs)
            logits = outputs.logits if hasattr(outputs, 'logits') else outputs
        else:
            # Fallback for engines without a forward method
            try:
                logits = self.swarm_engine(input_ids)
            except Exception as e:
                logger.error(f"Forward pass failed: {e}")
                # Emergency fallback: random logits on the same device as the input
                batch_size, seq_len = input_ids.shape
                logits = torch.randn(
                    batch_size, seq_len, self.config.vocab_size,
                    device=input_ids.device
                )

        loss = None
        if labels is not None:
            # Standard causal LM loss: shift logits and labels by one position
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=None,  # Mamba doesn't use a key-value cache
        )

    def generate(
        self,
        input_ids: torch.LongTensor,
        max_length: int = 100,
        temperature: float = 1.0,
        top_p: float = 0.9,
        do_sample: bool = True,
        **kwargs
    ) -> torch.LongTensor:
        """Generate text using the swarm model"""
        try:
            if hasattr(self.swarm_engine, 'generate'):
                return self.swarm_engine.generate(
                    input_ids=input_ids,
                    max_length=max_length,
                    temperature=temperature,
                    top_p=top_p,
                    do_sample=do_sample,
                    **kwargs
                )
            else:
                # Manual generation loop
                return self._manual_generate(input_ids, max_length, temperature, top_p, do_sample)
        except Exception as e:
            logger.error(f"Generation failed: {e}")
            # Fallback: pad the prompt with random tokens on the same device
            batch_size, seq_len = input_ids.shape
            new_tokens = torch.randint(
                0, self.config.vocab_size,
                (batch_size, max_length - seq_len),
                device=input_ids.device
            )
            return torch.cat([input_ids, new_tokens], dim=1)

    def _manual_generate(self, input_ids, max_length, temperature, top_p, do_sample):
        """Manual generation loop for engines without a generate method"""
        self.eval()
        with torch.no_grad():
            for _ in range(max_length - input_ids.size(1)):
                outputs = self.forward(input_ids)
                logits = outputs.logits[:, -1, :] / temperature

                if do_sample:
                    # Apply nucleus (top-p) filtering
                    if top_p < 1.0:
                        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                        cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
                        sorted_indices_to_remove = cumulative_probs > top_p
                        # Shift the mask right so the first token above the threshold is kept
                        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                        sorted_indices_to_remove[..., 0] = 0
                        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                        logits[indices_to_remove] = float('-inf')
                    probs = torch.softmax(logits, dim=-1)
                    next_token = torch.multinomial(probs, num_samples=1)
                else:
                    next_token = torch.argmax(logits, dim=-1, keepdim=True)

                input_ids = torch.cat([input_ids, next_token], dim=1)
        return input_ids

    def set_active_encoders(self, num_encoders: int):
        """Set the number of active encoders"""
        if hasattr(self.swarm_engine, 'set_active_encoders'):
            self.swarm_engine.set_active_encoders(num_encoders)
            self.num_active_encoders = num_encoders
        else:
            self.num_active_encoders = min(num_encoders, self.config.max_mamba_encoders)

    @classmethod
    def from_pretrained(cls, model_name_or_path, *model_args, **kwargs):
        """Load the model from pretrained weights"""
        try:
            return super().from_pretrained(model_name_or_path, *model_args, **kwargs)
        except Exception as e:
            logger.warning(f"Could not load pretrained model: {e}")
            # Fall back to a freshly initialized model with the default config
            config = MambaSwarmConfig()
            return cls(config)

    def get_num_params(self):
        """Get the total number of trainable parameters"""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
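

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the model definition). It assumes
# only the classes defined in this file plus a transformers release that
# exposes AutoConfig.register / AutoModelForCausalLM.register. The small
# config values below are arbitrary, chosen so the smoke test runs quickly;
# when the swarm packages are not importable, the mock engine is used.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoConfig, AutoModelForCausalLM

    # Optional: make the model loadable through the Auto* classes.
    AutoConfig.register("mamba_swarm", MambaSwarmConfig)
    AutoModelForCausalLM.register(MambaSwarmConfig, MambaSwarmForCausalLM)

    # Tiny config for a quick smoke test.
    config = MambaSwarmConfig(num_encoders=4, d_model=128, vocab_size=1000)
    model = MambaSwarmForCausalLM(config)
    print(f"Trainable parameters: {model.get_num_params():,}")

    # Forward pass with labels to obtain a language-modeling loss.
    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    outputs = model(input_ids=input_ids, labels=input_ids)
    print(f"Loss: {outputs.loss.item():.4f}, logits shape: {tuple(outputs.logits.shape)}")

    # Sampled continuation of the prompt.
    generated = model.generate(input_ids, max_length=32, temperature=0.8, top_p=0.9)
    print(f"Generated shape: {tuple(generated.shape)}")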