# =============================================================================
# core/tokenizer.py
# =============================================================================
from typing import Dict, List, Optional

import torch
from transformers import AutoTokenizer

from config import MambaConfig


class MambaTokenizer:
    """Thin wrapper around a Hugging Face tokenizer for Mamba training and inference."""

    def __init__(self, config: MambaConfig, tokenizer_name: str = "gpt2"):
        self.config = config
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

        # GPT-2 ships without a pad token; reuse EOS so padding works out of the box.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.vocab_size = len(self.tokenizer)

    def encode(self, text: str, max_length: Optional[int] = None) -> Dict[str, torch.Tensor]:
        """Encode a single text string to token ids."""
        if max_length is None:
            max_length = self.config.max_seq_len

        encoded = self.tokenizer(
            text,
            max_length=max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        return {
            "input_ids": encoded["input_ids"],
            "attention_mask": encoded["attention_mask"],
        }

    def encode_batch(self, texts: List[str], max_length: Optional[int] = None) -> Dict[str, torch.Tensor]:
        """Encode a batch of texts."""
        if max_length is None:
            max_length = self.config.max_seq_len

        encoded = self.tokenizer(
            texts,
            max_length=max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        return {
            "input_ids": encoded["input_ids"],
            "attention_mask": encoded["attention_mask"],
        }

    def decode(self, token_ids: torch.Tensor, skip_special_tokens: bool = True) -> str:
        """Decode token ids back to text."""
        return self.tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)

    def decode_batch(self, token_ids: torch.Tensor, skip_special_tokens: bool = True) -> List[str]:
        """Decode a batch of token ids back to a list of texts."""
        return self.tokenizer.batch_decode(token_ids, skip_special_tokens=skip_special_tokens)
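

# -----------------------------------------------------------------------------
# Usage sketch (not part of the original file). Assumptions: MambaConfig can be
# constructed with default arguments and exposes the max_seq_len attribute used
# above, and the "gpt2" tokenizer is available locally or via the Hugging Face
# Hub.
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    config = MambaConfig()
    tokenizer = MambaTokenizer(config, tokenizer_name="gpt2")

    # Round-trip a single string: encode to padded ids, then decode back.
    single = tokenizer.encode("State space models scale linearly with sequence length.")
    print(single["input_ids"].shape)          # (1, config.max_seq_len)
    print(tokenizer.decode(single["input_ids"][0]))

    # Batch encoding pads/truncates every text to the same max_length.
    texts = ["Hello Mamba!", "Selective state spaces replace attention."]
    batch = tokenizer.encode_batch(texts, max_length=16)
    print(batch["input_ids"].shape)           # (2, 16)
    print(tokenizer.decode_batch(batch["input_ids"]))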