# modeling_mamba_swarm.py - HuggingFace integration for Mamba Swarm
from transformers import PreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import CausalLMOutputWithPast
import torch
import torch.nn as nn
from typing import Optional, Tuple, Union
import logging
logger = logging.getLogger(__name__)
class MambaSwarmConfig(PretrainedConfig):
"""Configuration class for MambaSwarm model"""
model_type = "mamba_swarm"
def __init__(
self,
num_encoders=100,
max_mamba_encoders=100,
d_model=768,
vocab_size=50257,
max_sequence_length=2048,
encoder_config=None,
router_config=None,
aggregator_config=None,
**kwargs
):
self.num_encoders = num_encoders
self.max_mamba_encoders = max_mamba_encoders
self.d_model = d_model
self.vocab_size = vocab_size
self.max_sequence_length = max_sequence_length
self.encoder_config = encoder_config or {}
self.router_config = router_config or {}
self.aggregator_config = aggregator_config or {}
super().__init__(**kwargs)
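
# Example (illustrative sketch): constructing a configuration by hand. The keyword
# arguments mirror the __init__ signature above; the values shown are assumptions
# for a small test setup, not recommended defaults. save_pretrained/from_pretrained
# come from the standard PretrainedConfig API.
#
#   config = MambaSwarmConfig(
#       num_encoders=8,
#       d_model=256,
#       vocab_size=50257,
#       max_sequence_length=1024,
#   )
#   config.save_pretrained("./mamba_swarm_config")
#   reloaded = MambaSwarmConfig.from_pretrained("./mamba_swarm_config")
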
class MambaSwarmForCausalLM(PreTrainedModel):
"""HuggingFace compatible Mamba Swarm model"""
config_class = MambaSwarmConfig
def __init__(self, config):
super().__init__(config)
self.config = config
# Initialize core components
try:
# Try to use the unified swarm engine
from system.mambaSwarm import UnifiedMambaSwarm
self.swarm_engine = UnifiedMambaSwarm(
config=config,
use_pretrained=False # Use native implementation
)
self.num_active_encoders = getattr(self.swarm_engine, 'num_encoders', config.num_encoders)
logger.info("Initialized with UnifiedMambaSwarm")
except ImportError:
try:
# Fallback to native swarm integration
from core.mamba_swarm_integration import MambaEncoderSwarmModel
from core.config import MambaConfig
# Convert config to MambaConfig
mamba_config = MambaConfig(
d_model=config.d_model,
vocab_size=config.vocab_size,
n_layers=8, # Default
d_state=16, # Default
d_conv=4, # Default
bias=False # Default
)
self.swarm_engine = MambaEncoderSwarmModel(
mamba_config,
num_encoders=config.num_encoders
)
self.num_active_encoders = config.num_encoders
logger.info("Initialized with MambaEncoderSwarmModel")
except ImportError as e:
logger.error(f"Could not import swarm components: {e}")
# Create a minimal mock implementation
self.swarm_engine = self._create_mock_engine(config)
self.num_active_encoders = config.num_encoders
logger.warning("Using mock swarm engine")
    def _create_mock_engine(self, config):
        """Create a mock engine for testing purposes"""
        class MockSwarmEngine(nn.Module):
            """Minimal stand-in so the wrapper stays importable and testable without the swarm packages"""
            def __init__(self, config):
                super().__init__()
                self.config = config
                self.embedding = nn.Embedding(config.vocab_size, config.d_model)
                self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
                self.num_active_encoders = config.num_encoders
            def forward(self, input_ids, **kwargs):
                # Simple passthrough for testing: embed tokens and project back to the vocabulary
                embeddings = self.embedding(input_ids)
                logits = self.lm_head(embeddings)
                return type('MockOutput', (), {'logits': logits, 'past_key_values': None})()
            def generate(self, input_ids, max_length=100, **kwargs):
                # Simple generation for testing: append random token ids on the prompt's device
                batch_size, seq_len = input_ids.shape
                num_new = max(max_length - seq_len, 0)
                new_tokens = torch.randint(
                    0, self.config.vocab_size, (batch_size, num_new), device=input_ids.device
                )
                return torch.cat([input_ids, new_tokens], dim=1)
            def set_active_encoders(self, num):
                self.num_active_encoders = min(num, self.config.max_mamba_encoders)
        return MockSwarmEngine(config)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
labels: Optional[torch.LongTensor] = None,
**kwargs
) -> CausalLMOutputWithPast:
"""Forward pass through the swarm model"""
if input_ids is None:
raise ValueError("input_ids must be provided")
        # Get outputs from swarm engine (attention_mask is accepted for API
        # compatibility but is currently not forwarded to the engine)
        if hasattr(self.swarm_engine, 'forward'):
            outputs = self.swarm_engine.forward(input_ids, **kwargs)
            logits = outputs.logits if hasattr(outputs, 'logits') else outputs
        else:
            # Fallback for engines without a forward method
            try:
                logits = self.swarm_engine(input_ids)
            except Exception as e:
                logger.error(f"Forward pass failed: {e}")
                # Emergency fallback: random logits on the input device so callers still get a valid shape
                batch_size, seq_len = input_ids.shape
                logits = torch.randn(
                    batch_size, seq_len, self.config.vocab_size, device=input_ids.device
                )
loss = None
if labels is not None:
            # Standard causal LM cross-entropy: shift logits/labels so tokens < n predict token n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=None, # Mamba doesn't use key-value cache
)
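
    # Example (sketch): a single training-style forward pass. `model` is assumed to be
    # an instance of MambaSwarmForCausalLM and `tokenizer` an assumed compatible
    # Hugging Face tokenizer whose ids fit the configured vocab_size (neither is
    # defined in this file). Passing labels=input_ids yields the shifted
    # next-token cross-entropy loss computed above.
    #
    #   input_ids = tokenizer("Hello swarm", return_tensors="pt").input_ids
    #   out = model(input_ids=input_ids, labels=input_ids)
    #   print(out.loss, out.logits.shape)   # scalar loss, (batch, seq_len, vocab_size)
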
def generate(
self,
input_ids: torch.LongTensor,
max_length: int = 100,
temperature: float = 1.0,
top_p: float = 0.9,
do_sample: bool = True,
**kwargs
) -> torch.LongTensor:
"""Generate text using the swarm model"""
try:
if hasattr(self.swarm_engine, 'generate'):
return self.swarm_engine.generate(
input_ids=input_ids,
max_length=max_length,
temperature=temperature,
top_p=top_p,
do_sample=do_sample,
**kwargs
)
else:
# Manual generation loop
return self._manual_generate(input_ids, max_length, temperature, top_p, do_sample)
        except Exception as e:
            logger.error(f"Generation failed: {e}")
            # Fallback: pad the prompt with random tokens on the same device so the
            # caller still receives a tensor of the requested length
            batch_size, seq_len = input_ids.shape
            num_new = max(max_length - seq_len, 0)
            new_tokens = torch.randint(
                0, self.config.vocab_size, (batch_size, num_new), device=input_ids.device
            )
            return torch.cat([input_ids, new_tokens], dim=1)
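
    # Example (sketch; `model` and `tokenizer` are assumed as in the forward example
    # above): prompt the swarm and decode the sampled continuation.
    #
    #   prompt_ids = tokenizer("The swarm routes each prompt", return_tensors="pt").input_ids
    #   generated = model.generate(prompt_ids, max_length=64, temperature=0.8, top_p=0.9)
    #   print(tokenizer.decode(generated[0]))
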
def _manual_generate(self, input_ids, max_length, temperature, top_p, do_sample):
"""Manual generation when swarm engine doesn't have generate method"""
self.eval()
with torch.no_grad():
for _ in range(max_length - input_ids.size(1)):
outputs = self.forward(input_ids)
logits = outputs.logits[:, -1, :] / temperature
if do_sample:
                    # Apply top-p (nucleus) filtering: keep the smallest set of tokens
                    # whose cumulative probability exceeds top_p
                    if top_p < 1.0:
                        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                        cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
                        sorted_indices_to_remove = cumulative_probs > top_p
                        # Shift the mask right so the first token that crosses the threshold is kept
                        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                        sorted_indices_to_remove[..., 0] = 0
                        # Map the mask from sorted order back to vocabulary order
                        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                        logits[indices_to_remove] = float('-inf')
probs = torch.softmax(logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
else:
next_token = torch.argmax(logits, dim=-1, keepdim=True)
input_ids = torch.cat([input_ids, next_token], dim=1)
return input_ids
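
    # Worked illustration (sketch) of the nucleus filter above, using a toy vocabulary
    # of four tokens and top_p = 0.9:
    #   logits  [2.0, 1.0, 0.5, -1.0] -> softmax probs ≈ [0.61, 0.22, 0.14, 0.03]
    #   cumulative probs ≈ [0.61, 0.83, 0.97, 1.00] -> mask (cum > 0.9) = [F, F, T, T]
    #   after the right-shift the mask becomes [F, F, F, T], so only the 0.03 tail token
    #   is set to -inf and sampling is restricted to the three-token "nucleus".
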
    def set_active_encoders(self, num_encoders: int):
        """Set the number of active encoders, clamped to the configured maximum"""
        num_encoders = min(num_encoders, self.config.max_mamba_encoders)
        if hasattr(self.swarm_engine, 'set_active_encoders'):
            self.swarm_engine.set_active_encoders(num_encoders)
        self.num_active_encoders = num_encoders
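
    # Example (sketch): scaling the swarm at runtime. The requested count is clamped
    # to config.max_mamba_encoders, so asking for more than the maximum simply
    # activates every available encoder.
    #
    #   model.set_active_encoders(16)
    #   print(model.num_active_encoders)
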
@classmethod
def from_pretrained(cls, model_name_or_path, *model_args, **kwargs):
"""Load model from pretrained weights"""
try:
return super().from_pretrained(model_name_or_path, *model_args, **kwargs)
except Exception as e:
logger.warning(f"Could not load pretrained model: {e}")
            # Fall back to a randomly initialized model with the default config
config = MambaSwarmConfig()
return cls(config)
def get_num_params(self):
"""Get total number of parameters"""
return sum(p.numel() for p in self.parameters() if p.requires_grad)
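

# Minimal smoke test (sketch). Running `python modeling_mamba_swarm.py` exercises the
# wrapper end to end; when the swarm packages are not importable the mock engine is
# used, so outputs are random but shapes and the API can still be checked. The Auto
# registration is optional and assumes a transformers version that exposes
# AutoConfig.register / AutoModelForCausalLM.register.
if __name__ == "__main__":
    from transformers import AutoConfig, AutoModelForCausalLM

    AutoConfig.register("mamba_swarm", MambaSwarmConfig)
    AutoModelForCausalLM.register(MambaSwarmConfig, MambaSwarmForCausalLM)

    config = MambaSwarmConfig(num_encoders=4, d_model=128, vocab_size=1000, max_sequence_length=128)
    model = MambaSwarmForCausalLM(config)
    print(f"Trainable parameters: {model.get_num_params():,}")

    dummy_ids = torch.randint(0, config.vocab_size, (1, 8))
    outputs = model(input_ids=dummy_ids, labels=dummy_ids)
    print(f"Logits shape: {tuple(outputs.logits.shape)}, loss: {outputs.loss.item():.4f}")

    generated = model.generate(dummy_ids, max_length=16, do_sample=True, top_p=0.9)
    print(f"Generated shape: {tuple(generated.shape)}")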