Spaces:
Running
Running
| # ============================================================================= | |
| # core/tokenizer.py | |
| # ============================================================================= | |
| from transformers import AutoTokenizer | |
| import torch | |
| from config import MambaConfig | |
| from typing import List, Dict, Union | |
class MambaTokenizer:
    """Thin wrapper around a Hugging Face tokenizer for the Mamba model.

    Loads a pretrained tokenizer by name, guarantees that a padding token is
    set, and exposes encode/decode helpers that pad or truncate every
    sequence to a fixed length and return PyTorch tensors.
    """

    def __init__(self, config: "MambaConfig", tokenizer_name: str = "gpt2"):
        """Load the pretrained tokenizer and record its vocabulary size.

        Args:
            config: Model configuration; ``config.max_seq_len`` is used as
                the default sequence length when encoding.
            tokenizer_name: Name or path passed to
                ``AutoTokenizer.from_pretrained``.
        """
        self.config = config
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

        # The encode helpers pad to a fixed length, which requires a pad
        # token; reuse EOS when the loaded tokenizer does not define one.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # len() is taken after the pad-token fallback above.
        self.vocab_size = len(self.tokenizer)

    def _tokenize(
        self,
        inputs: Union[str, List[str]],
        max_length: Union[int, None],
    ) -> Dict[str, torch.Tensor]:
        """Tokenize one text or a list of texts with fixed-length padding."""
        if max_length is None:
            max_length = self.config.max_seq_len

        encoded = self.tokenizer(
            inputs,
            max_length=max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        # Expose only the two fields downstream code relies on.
        return {
            "input_ids": encoded["input_ids"],
            "attention_mask": encoded["attention_mask"],
        }

    def encode(self, text: str, max_length: Union[int, None] = None) -> Dict[str, torch.Tensor]:
        """Encode a single text to token ids.

        Args:
            text: Text to tokenize.
            max_length: Length to pad/truncate to. Defaults to
                ``config.max_seq_len`` when ``None``.

        Returns:
            Dict with ``"input_ids"`` and ``"attention_mask"`` tensors. With
            ``return_tensors="pt"`` a single text still carries a leading
            batch dimension of 1.
        """
        return self._tokenize(text, max_length)

    def encode_batch(self, texts: List[str], max_length: Union[int, None] = None) -> Dict[str, torch.Tensor]:
        """Encode a batch of texts to token ids.

        Args:
            texts: Texts to tokenize; every row is padded/truncated to the
                same length.
            max_length: Length to pad/truncate to. Defaults to
                ``config.max_seq_len`` when ``None``.

        Returns:
            Dict with ``"input_ids"`` and ``"attention_mask"`` tensors, one
            row per input text.
        """
        return self._tokenize(texts, max_length)

    def decode(self, token_ids: torch.Tensor, skip_special_tokens: bool = True) -> str:
        """Decode one sequence of token ids to text.

        Accepts either a flat sequence of ids or a tensor with a leading
        batch dimension of 1 (the shape ``encode`` produces), so that
        ``decode(encode(text)["input_ids"])`` round-trips. Use
        ``decode_batch`` for several sequences.

        Args:
            token_ids: Token ids of a single sequence.
            skip_special_tokens: Whether to drop special tokens (padding,
                EOS, ...) from the decoded text.

        Returns:
            The decoded string.
        """
        # The underlying decode expects a single flat sequence; unwrap the
        # batch dimension that encode() adds for a single text.
        if (
            isinstance(token_ids, torch.Tensor)
            and token_ids.dim() == 2
            and token_ids.size(0) == 1
        ):
            token_ids = token_ids[0]
        return self.tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)

    def decode_batch(self, token_ids: torch.Tensor, skip_special_tokens: bool = True) -> List[str]:
        """Decode a batch of token-id sequences to a list of strings.

        Args:
            token_ids: Token ids, one sequence per row.
            skip_special_tokens: Whether to drop special tokens from the
                decoded texts.

        Returns:
            One decoded string per input row.
        """
        return self.tokenizer.batch_decode(token_ids, skip_special_tokens=skip_special_tokens)