JunxiongWang committed on
Commit
c5f1c0b
·
verified ·
1 Parent(s): 6d88a20

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +96 -1
README.md CHANGED
@@ -2,4 +2,99 @@
2
  license: apache-2.0
3
  ---
4
 
5
- Train in 30B Byte. Mode size 353M. Table 2 in [MambaByte](https://arxiv.org/abs/2401.13660)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  license: apache-2.0
3
  ---
4
 
5
+ Trained on 30B bytes. Model size: 353M. See Table 2 in [MambaByte](https://arxiv.org/abs/2401.13660).
6
+
7
+ To use the model:
8
+
9
+ ```python
10
+ import torch
11
+ from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
12
+
13
+ import numpy as np
14
+
15
+ model=MambaLMHeadModel.from_pretrained("JunxiongWang/MambaByte_Code", device='cuda', dtype=torch.float32)
16
+
17
+ text = "import torch"
18
+ text_byte = np.frombuffer(text.encode('utf-8'), dtype=np.uint8)
19
+ input_ids = torch.from_numpy(text_byte[None, :]).long().cuda()
20
+
21
+ sample = model.generate(
22
+ input_ids=input_ids,
23
+ max_length=2048,
24
+ cg=True,
25
+ return_dict_in_generate=True,
26
+ output_scores=True,
27
+ enable_timing=True,
28
+ temperature=1,
29
+ top_k=256,
30
+ top_p=0.9,
31
+ )
32
+
33
+ print(bytes(sample.sequences[0].tolist()).decode('utf-8'))
34
+ ```
35
+
36
+ Output
37
+
38
+ ```
39
+ import torch
40
+ import numpy as np
41
+ import torch.nn.functional as F
42
+ from torch.autograd import Variable
43
+
44
+ from networkx.states import TransientState
45
+
46
+ def extract_data(num_epochs, epochs, is_last_epoch):
47
+
48
+ def get_data(num_features, num_classes):
49
+ data_features = num_features
50
+ data_classes = num_classes
51
+ data_labels = num_epochs
52
+
53
+ if num_features == 0 or num_classes == 0:
54
+ return data_features, data_classes
55
+ if is_last_epoch:
56
+ data_features = num_features
57
+ data_classes = num_classes
58
+ data_labels = num_epochs
59
+ return data_features, data_classes
60
+
61
+ data_features, data_classes = get_data(num_epochs, epochs, is_last_epoch)
62
+ data_labels = num_epochs * 2
63
+ return data_features, data_classes
64
+
65
+
66
+ class NumChannel:
67
+ def __init__(self, x, y, dx=1, dy=1, idx=1, data_size=2, epoch=None):
68
+ """idx is the channel index with data feature in the first epoch.
69
+ x is the channel of the input data.
70
+ y is the element of the input data.
71
+ dx is the element of the data feature of the input data.
72
+ data_size is the size of the element of the data.
73
+ epoch is the channel of the element of the data.
74
+ """
75
+ self.x = x
76
+ self.y = y
77
+ self.dx = dx
78
+ self.data_size = data_size
79
+ self.epoch = epoch
80
+ self.reference_count = 0
81
+ self.data_features = {}
82
+ self.data_classes = {}
83
+
84
+ self._initialize()
85
+ if idx is not None:
86
+ self._start_time = time.time()
87
+
88
+ def _initialize(self):
89
+ """idx is the channel index with data feature in the first epoch.
90
+ x is the channel of the input data.
91
+ y is the element of the input data.
92
+ dx is the element of the data feature of the input data.
93
+ data_size is the size of the element of the data.
94
+ epoch is the channel of the element of the data.
95
+ """
96
+ self.idx = idx
97
+
98
+ def _initialize(self):
99
+ """idx is the channel of the inpu
100
+ ```