Debito committed on
Commit 5115fc5 · verified · 1 parent: aec13a2

Delete tokenizer.py

Files changed (1)
  1. tokenizer.py +0 -63
tokenizer.py DELETED
@@ -1,63 +0,0 @@
-# =============================================================================
-# core/tokenizer.py
-# =============================================================================
-from transformers import AutoTokenizer
-import torch
-from config import MambaConfig
-from typing import List, Dict, Optional
-
-class MambaTokenizer:
-    def __init__(self, config: MambaConfig, tokenizer_name: str = "gpt2"):
-        self.config = config
-        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
-
-        # Use the EOS token as the pad token if none is defined (e.g. GPT-2)
-        if self.tokenizer.pad_token is None:
-            self.tokenizer.pad_token = self.tokenizer.eos_token
-
-        self.vocab_size = len(self.tokenizer)
-
-    def encode(self, text: str, max_length: Optional[int] = None) -> Dict[str, torch.Tensor]:
-        """Encode text to token ids"""
-        if max_length is None:
-            max_length = self.config.max_seq_len
-
-        encoded = self.tokenizer(
-            text,
-            max_length=max_length,
-            padding="max_length",
-            truncation=True,
-            return_tensors="pt"
-        )
-
-        return {
-            "input_ids": encoded["input_ids"],
-            "attention_mask": encoded["attention_mask"]
-        }
-
-    def encode_batch(self, texts: List[str], max_length: Optional[int] = None) -> Dict[str, torch.Tensor]:
-        """Encode a batch of texts"""
-        if max_length is None:
-            max_length = self.config.max_seq_len
-
-        encoded = self.tokenizer(
-            texts,
-            max_length=max_length,
-            padding="max_length",
-            truncation=True,
-            return_tensors="pt"
-        )
-
-        return {
-            "input_ids": encoded["input_ids"],
-            "attention_mask": encoded["attention_mask"]
-        }
-
-    def decode(self, token_ids: torch.Tensor, skip_special_tokens: bool = True) -> str:
-        """Decode token ids to text"""
-        return self.tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
-
-    def decode_batch(self, token_ids: torch.Tensor, skip_special_tokens: bool = True) -> List[str]:
-        """Decode a batch of token ids"""
-        return self.tokenizer.batch_decode(token_ids, skip_special_tokens=skip_special_tokens)
-
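For context, a minimal sketch of how the deleted class was used. The import path and the `MambaConfig` constructor arguments are hypothetical; the deleted code only relies on a `config.max_seq_len` attribute.

from config import MambaConfig
from core.tokenizer import MambaTokenizer  # hypothetical path, per the file header above

# Hypothetical config values; only max_seq_len is actually read by the tokenizer.
config = MambaConfig(max_seq_len=1024)
tokenizer = MambaTokenizer(config, tokenizer_name="gpt2")

batch = tokenizer.encode("State space models are fast.")
print(batch["input_ids"].shape)                 # torch.Size([1, 1024]): padded to max_seq_len
print(tokenizer.decode(batch["input_ids"][0]))  # pad/EOS tokens stripped, original text recovered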