Debito committed
Commit e6d86b2 · verified · 1 parent: 7aad614

Upload 3 files

Files changed (3)
  1. configuration_mamba_swarm.py +58 -0
  2. tokenizer.py +63 -0
  3. vocab.json +0 -0
configuration_mamba_swarm.py ADDED
@@ -0,0 +1,58 @@
+ from transformers import PretrainedConfig
+
+ class MambaSwarmConfig(PretrainedConfig):
+     model_type = "mamba_swarm"
+
+     def __init__(
+         self,
+         num_mamba_encoders=5,
+         max_mamba_encoders=1000,
+         d_model=768,
+         d_state=16,
+         d_conv=4,
+         expand_factor=2,
+         vocab_size=50257,
+         max_sequence_length=2048,
+         pad_token_id=50256,
+         bos_token_id=50256,
+         eos_token_id=50256,
+         tie_word_embeddings=False,
+         use_cache=True,
+         gating_config=None,
+         routing_config=None,
+         **kwargs
+     ):
+         super().__init__(
+             pad_token_id=pad_token_id,
+             bos_token_id=bos_token_id,
+             eos_token_id=eos_token_id,
+             tie_word_embeddings=tie_word_embeddings,
+             **kwargs
+         )
+
+         self.num_mamba_encoders = num_mamba_encoders
+         self.max_mamba_encoders = max_mamba_encoders
+         self.d_model = d_model
+         self.d_state = d_state
+         self.d_conv = d_conv
+         self.expand_factor = expand_factor
+         self.vocab_size = vocab_size
+         self.max_sequence_length = max_sequence_length
+         self.use_cache = use_cache
+
+         # Default gating configuration
+         if gating_config is None:
+             gating_config = {
+                 "gating_type": "learned",
+                 "top_k": 2,
+                 "load_balancing_loss_coef": 0.01
+             }
+         self.gating_config = gating_config
+
+         # Default routing configuration
+         if routing_config is None:
+             routing_config = {
+                 "routing_strategy": "dynamic",
+                 "aggregation_method": "weighted_average"
+             }
+         self.routing_config = routing_config
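
A quick usage sketch (not part of the commit; the save path is made up for illustration). Because MambaSwarmConfig subclasses PretrainedConfig, it inherits save_pretrained/from_pretrained, so the swarm-specific fields round-trip through config.json:

from configuration_mamba_swarm import MambaSwarmConfig

# Override a couple of swarm-specific fields; everything else keeps its default.
config = MambaSwarmConfig(num_mamba_encoders=8, d_model=512)
config.save_pretrained("./mamba_swarm")               # hypothetical path; writes config.json
reloaded = MambaSwarmConfig.from_pretrained("./mamba_swarm")
print(reloaded.gating_config["top_k"])                # 2, from the default gating config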
tokenizer.py ADDED
@@ -0,0 +1,63 @@
+ # =============================================================================
+ # core/tokenizer.py
+ # =============================================================================
+ from transformers import AutoTokenizer
+ import torch
+ from configuration_mamba_swarm import MambaSwarmConfig
+ from typing import Dict, List, Optional
+
+ class MambaTokenizer:
+     def __init__(self, config: MambaSwarmConfig, tokenizer_name: str = "gpt2"):
+         self.config = config
+         self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
+
+         # GPT-2 ships without a pad token; reuse EOS so padding works
+         if self.tokenizer.pad_token is None:
+             self.tokenizer.pad_token = self.tokenizer.eos_token
+
+         self.vocab_size = len(self.tokenizer)
+
+     def encode(self, text: str, max_length: Optional[int] = None) -> Dict[str, torch.Tensor]:
+         """Encode a single text into padded token ids and an attention mask"""
+         if max_length is None:
+             max_length = self.config.max_sequence_length
+
+         encoded = self.tokenizer(
+             text,
+             max_length=max_length,
+             padding="max_length",
+             truncation=True,
+             return_tensors="pt"
+         )
+
+         return {
+             "input_ids": encoded["input_ids"],
+             "attention_mask": encoded["attention_mask"]
+         }
+
+     def encode_batch(self, texts: List[str], max_length: Optional[int] = None) -> Dict[str, torch.Tensor]:
+         """Encode a batch of texts"""
+         if max_length is None:
+             max_length = self.config.max_sequence_length
+
+         encoded = self.tokenizer(
+             texts,
+             max_length=max_length,
+             padding="max_length",
+             truncation=True,
+             return_tensors="pt"
+         )
+
+         return {
+             "input_ids": encoded["input_ids"],
+             "attention_mask": encoded["attention_mask"]
+         }
+
+     def decode(self, token_ids: torch.Tensor, skip_special_tokens: bool = True) -> str:
+         """Decode token ids to text"""
+         return self.tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
+
+     def decode_batch(self, token_ids: torch.Tensor, skip_special_tokens: bool = True) -> List[str]:
+         """Decode a batch of token ids"""
+         return self.tokenizer.batch_decode(token_ids, skip_special_tokens=skip_special_tokens)
+
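
For reference, a minimal end-to-end sketch tying the two files together (not part of the commit; assumes both modules are on the Python path, and the first run downloads the GPT-2 tokenizer from the Hub):

from configuration_mamba_swarm import MambaSwarmConfig
from tokenizer import MambaTokenizer

config = MambaSwarmConfig()
tok = MambaTokenizer(config)                          # wraps the GPT-2 BPE tokenizer
batch = tok.encode_batch(["Hello swarm", "Mamba state spaces"], max_length=16)
print(batch["input_ids"].shape)                       # torch.Size([2, 16])
print(tok.decode(batch["input_ids"][0]))              # "Hello swarm" (EOS padding stripped)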
vocab.json ADDED
The diff for this file is too large to render.