|
--- |
|
license: apache-2.0 |
|
metrics: |
|
- perplexity |
|
pipeline_tag: text-generation |
|
--- |
|
|
|
Trained on 30B bytes. Model size: 353M parameters. This is the model reported in Table 2 of the [MambaByte](https://arxiv.org/abs/2401.13660) paper.
|
|
|
Usage:
|
|
|
``` |
|
import numpy as np
import torch

from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

model = MambaLMHeadModel.from_pretrained(
    "JunxiongWang/MambaByte_Code", device="cuda", dtype=torch.float32
)

# No tokenizer: the prompt is encoded to UTF-8 and each raw byte value
# (0-255) is used directly as an input id.
text = "import torch"
text_byte = np.frombuffer(text.encode("utf-8"), dtype=np.uint8)

# `[None, :]` adds a batch dimension, giving shape (1, num_bytes).
# `.copy()` is needed because np.frombuffer over a `bytes` object returns a
# read-only array, which torch.from_numpy does not accept cleanly.
input_ids = torch.from_numpy(text_byte[None, :].copy()).long().cuda()

sample = model.generate(
    input_ids=input_ids,
    max_length=2048,
    cg=True,
    return_dict_in_generate=True,
    output_scores=True,
    enable_timing=True,
    temperature=1,
    top_k=256,
    top_p=0.9,
)

# Sampling byte by byte does not guarantee valid UTF-8 (e.g. a multi-byte
# character can be cut off at max_length), so replace undecodable bytes
# instead of letting a strict decode raise UnicodeDecodeError.
print(bytes(sample.sequences[0].tolist()).decode("utf-8", errors="replace"))
|
``` |
|
|
|
Example output (generation is sampled, so results vary between runs):
|
|
|
``` |
|
import torch |
|
import numpy as np |
|
import torch.nn.functional as F |
|
from torch.autograd import Variable |
|
|
|
from networkx.states import TransientState |
|
|
|
def extract_data(num_epochs, epochs, is_last_epoch): |
|
|
|
def get_data(num_features, num_classes): |
|
data_features = num_features |
|
data_classes = num_classes |
|
data_labels = num_epochs |
|
|
|
if num_features == 0 or num_classes == 0: |
|
return data_features, data_classes |
|
if is_last_epoch: |
|
data_features = num_features |
|
data_classes = num_classes |
|
data_labels = num_epochs |
|
return data_features, data_classes |
|
|
|
data_features, data_classes = get_data(num_epochs, epochs, is_last_epoch) |
|
data_labels = num_epochs * 2 |
|
return data_features, data_classes |
|
|
|
|
|
class NumChannel: |
|
def __init__(self, x, y, dx=1, dy=1, idx=1, data_size=2, epoch=None): |
|
"""idx is the channel index with data feature in the first epoch. |
|
x is the channel of the input data. |
|
y is the element of the input data. |
|
dx is the element of the data feature of the input data. |
|
data_size is the size of the element of the data. |
|
epoch is the channel of the element of the data. |
|
""" |
|
self.x = x |
|
self.y = y |
|
self.dx = dx |
|
self.data_size = data_size |
|
self.epoch = epoch |
|
self.reference_count = 0 |
|
self.data_features = {} |
|
self.data_classes = {} |
|
|
|
self._initialize() |
|
if idx is not None: |
|
self._start_time = time.time() |
|
|
|
def _initialize(self): |
|
"""idx is the channel index with data feature in the first epoch. |
|
x is the channel of the input data. |
|
y is the element of the input data. |
|
dx is the element of the data feature of the input data. |
|
data_size is the size of the element of the data. |
|
epoch is the channel of the element of the data. |
|
""" |
|
self.idx = idx |
|
``` |