caesar-one committed
Commit b5eb6ef · verified · 1 Parent(s): 5d7c5ae

Upload 4 files

Files changed (4)
  1. __init__.py +0 -0
  2. artifact.metadata +60 -0
  3. colbert_configuration.py +412 -0
  4. tokenization_utils.py +191 -0
__init__.py ADDED
File without changes
artifact.metadata ADDED
@@ -0,0 +1,60 @@
{
    "query_token_id": "[unused0]",
    "doc_token_id": "[unused1]",
    "query_token": "[Q]",
    "doc_token": "[D]",
    "ncells": null,
    "centroid_score_threshold": null,
    "ndocs": null,
    "load_index_with_mmap": false,
    "index_path": null,
    "index_bsize": 64,
    "nbits": 1,
    "kmeans_niters": 4,
    "resume": false,
    "similarity": "cosine",
    "bsize": 2,
    "accumsteps": 2,
    "lr": 1e-5,
    "maxsteps": 500000,
    "save_every": null,
    "warmup": 20000,
    "warmup_bert": null,
    "relu": false,
    "nway": 64,
    "use_ib_negatives": true,
    "reranker": false,
    "distillation_alpha": 1.0,
    "ignore_scores": false,
    "model_name": null,
    "query_maxlen": 32,
    "attend_to_mask_tokens": false,
    "interaction": "colbert",
    "dim": 128,
    "doc_maxlen": 250,
    "mask_punctuation": true,
    "checkpoint": "bert-base-uncased",
    "triples": ".\/examples.json",
    "collection": ".\/collection.tsv",
    "queries": ".\/queries.train.tsv",
    "index_name": null,
    "overwrite": false,
    "root": "",
    "experiment": "default",
    "index_root": null,
    "name": "2024-06\/30\/19.56.27",
    "rank": 0,
    "nranks": 8,
    "amp": true,
    "gpus": 8,
    "avoid_fork_if_possible": false,
    "meta": {
        "hostname": "",
        "git_branch": "main",
        "git_hash": "",
        "git_commit_datetime": "2024-06-26 12:50:22+00:00",
        "current_datetime": "Jul 04, 2024 ; 3:51AM UTC (+0000)",
        "cmd": "train.py",
        "version": "colbert-v0.4"
    }
}
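
Note: these recorded settings can be read back directly. A minimal sketch (not part of the upload), assuming the file sits in the working directory:

import ujson

# Inspect the training hyperparameters recorded in artifact.metadata above.
with open("artifact.metadata") as f:
    metadata = ujson.load(f)

print(metadata["checkpoint"])  # bert-base-uncased
print(metadata["dim"])         # 128-dim token embeddings
print(metadata["nway"], metadata["use_ib_negatives"])  # 64-way training with in-batch negatives

This is the same file that BaseConfig.load_from_checkpoint in colbert_configuration.py below locates (via hf_hub_download) and parses.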
colbert_configuration.py ADDED
@@ -0,0 +1,412 @@
import os
import sys
import time
import socket
import datetime
import dataclasses
from dataclasses import dataclass, fields
from typing import Any

import __main__
import git
import torch
import ujson
from huggingface_hub import hf_hub_download


def torch_load_dnn(path):
    if path.startswith("http:") or path.startswith("https:"):
        dnn = torch.hub.load_state_dict_from_url(path, map_location='cpu')
    else:
        dnn = torch.load(path, map_location='cpu')

    return dnn


class dotdict(dict):
    """
    dot.notation access to dictionary attributes
    Credit: derek73 @ https://stackoverflow.com/questions/2352181
    """
    __getattr__ = dict.__getitem__
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__


def get_metadata_only():
    args = dotdict()

    args.hostname = socket.gethostname()
    try:
        args.git_branch = git.Repo(search_parent_directories=True).active_branch.name
        args.git_hash = git.Repo(search_parent_directories=True).head.object.hexsha
        args.git_commit_datetime = str(git.Repo(search_parent_directories=True).head.object.committed_datetime)
    except git.exc.InvalidGitRepositoryError:
        pass  # not running inside a git checkout; omit the git fields
    args.current_datetime = time.strftime('%b %d, %Y ; %l:%M%p %Z (%z)')
    args.cmd = ' '.join(sys.argv)

    return args


def timestamp(daydir=False):
    format_str = f"%Y-%m{'/' if daydir else '-'}%d{'/' if daydir else '_'}%H.%M.%S"
    result = datetime.datetime.now().strftime(format_str)
    return result


@dataclass
class DefaultVal:
    val: Any

    def __hash__(self):
        return hash(repr(self.val))

    def __eq__(self, other):
        # The uploaded version was missing `return`, so every comparison yielded None.
        return self.val == other.val


@dataclass
class RunSettings:
    """
    The defaults here have a special status in Run(), which initially calls assign_defaults(),
    so these aren't soft defaults in that specific context.
    """

    overwrite: bool = DefaultVal(False)

    root: str = DefaultVal(os.path.join(os.getcwd(), 'experiments'))
    experiment: str = DefaultVal('default')

    index_root: str = DefaultVal(None)
    name: str = DefaultVal(timestamp(daydir=True))

    rank: int = DefaultVal(0)
    nranks: int = DefaultVal(1)
    amp: bool = DefaultVal(True)

    total_visible_gpus = torch.cuda.device_count()
    gpus: int = DefaultVal(total_visible_gpus)

    avoid_fork_if_possible: bool = DefaultVal(False)

    @property
    def gpus_(self):
        value = self.gpus

        if isinstance(value, int):
            value = list(range(value))

        if isinstance(value, str):
            value = value.split(',')

        value = list(map(int, value))
        value = sorted(list(set(value)))

        assert all(device_idx in range(0, self.total_visible_gpus) for device_idx in value), value

        return value

    @property
    def index_root_(self):
        return self.index_root or os.path.join(self.root, self.experiment, 'indexes/')

    @property
    def script_name_(self):
        if '__file__' in dir(__main__):
            cwd = os.path.abspath(os.getcwd())
            script_path = os.path.abspath(__main__.__file__)
            root_path = os.path.abspath(self.root)

            if script_path.startswith(cwd):
                script_path = script_path[len(cwd):]
            else:
                try:
                    commonpath = os.path.commonpath([script_path, root_path])
                    script_path = script_path[len(commonpath):]
                except ValueError:  # the paths share no common prefix (e.g., different drives)
                    pass

            assert script_path.endswith('.py')
            script_name = script_path.replace('/', '.').strip('.')[:-3]

            assert len(script_name) > 0, (script_name, script_path, cwd)

            return script_name

        return 'none'

    @property
    def path_(self):
        return os.path.join(self.root, self.experiment, self.script_name_, self.name)

    @property
    def device_(self):
        return self.gpus_[self.rank % self.nranks]


@dataclass
class TokenizerSettings:
    query_token_id: str = DefaultVal("[unused0]")
    doc_token_id: str = DefaultVal("[unused1]")
    query_token: str = DefaultVal("[Q]")
    doc_token: str = DefaultVal("[D]")


@dataclass
class ResourceSettings:
    checkpoint: str = DefaultVal(None)
    triples: str = DefaultVal(None)
    collection: str = DefaultVal(None)
    queries: str = DefaultVal(None)
    index_name: str = DefaultVal(None)


@dataclass
class DocSettings:
    dim: int = DefaultVal(128)
    doc_maxlen: int = DefaultVal(220)
    mask_punctuation: bool = DefaultVal(True)


@dataclass
class QuerySettings:
    query_maxlen: int = DefaultVal(32)
    attend_to_mask_tokens: bool = DefaultVal(False)
    interaction: str = DefaultVal('colbert')


@dataclass
class TrainingSettings:
    similarity: str = DefaultVal('cosine')

    bsize: int = DefaultVal(32)

    accumsteps: int = DefaultVal(1)

    lr: float = DefaultVal(3e-06)

    maxsteps: int = DefaultVal(500_000)

    save_every: int = DefaultVal(None)

    resume: bool = DefaultVal(False)

    ## NEW:
    warmup: int = DefaultVal(None)

    warmup_bert: int = DefaultVal(None)

    relu: bool = DefaultVal(False)

    nway: int = DefaultVal(2)

    use_ib_negatives: bool = DefaultVal(False)

    reranker: bool = DefaultVal(False)

    distillation_alpha: float = DefaultVal(1.0)

    ignore_scores: bool = DefaultVal(False)

    model_name: str = DefaultVal(None)  # DefaultVal('bert-base-uncased')


@dataclass
class IndexingSettings:
    index_path: str = DefaultVal(None)

    index_bsize: int = DefaultVal(64)

    nbits: int = DefaultVal(1)

    kmeans_niters: int = DefaultVal(4)

    resume: bool = DefaultVal(False)

    @property
    def index_path_(self):
        return self.index_path or os.path.join(self.index_root_, self.index_name)


@dataclass
class SearchSettings:
    ncells: int = DefaultVal(None)
    centroid_score_threshold: float = DefaultVal(None)
    ndocs: int = DefaultVal(None)
    load_index_with_mmap: bool = DefaultVal(False)


@dataclass
class CoreConfig:
    def __post_init__(self):
        """
        Source: https://stackoverflow.com/a/58081120/1493011
        """

        self.assigned = {}

        for field in fields(self):
            field_val = getattr(self, field.name)

            if isinstance(field_val, DefaultVal) or field_val is None:
                setattr(self, field.name, field.default.val)

            if not isinstance(field_val, DefaultVal):
                self.assigned[field.name] = True

    def assign_defaults(self):
        for field in fields(self):
            setattr(self, field.name, field.default.val)
            self.assigned[field.name] = True

    def configure(self, ignore_unrecognized=True, **kw_args):
        ignored = set()

        for key, value in kw_args.items():
            self.set(key, value, ignore_unrecognized) or ignored.update({key})

        return ignored

        """
        # TODO: Take a config object, not kw_args.

        for key in config.assigned:
            value = getattr(config, key)
        """

    def set(self, key, value, ignore_unrecognized=False):
        if hasattr(self, key):
            setattr(self, key, value)
            self.assigned[key] = True
            return True

        if not ignore_unrecognized:
            raise Exception(f"Unrecognized key `{key}` for {type(self)}")

    def help(self):
        print(ujson.dumps(self.export(), indent=4))

    def __export_value(self, v):
        v = v.provenance() if hasattr(v, 'provenance') else v

        if isinstance(v, list) and len(v) > 100:
            v = (f"list with {len(v)} elements starting with...", v[:3])

        if isinstance(v, dict) and len(v) > 100:
            v = (f"dict with {len(v)} keys starting with...", list(v.keys())[:3])

        return v

    def export(self):
        d = dataclasses.asdict(self)

        for k, v in d.items():
            d[k] = self.__export_value(v)

        return d


@dataclass
class BaseConfig(CoreConfig):
    @classmethod
    def from_existing(cls, *sources):
        kw_args = {}

        for source in sources:
            if source is None:
                continue

            local_kw_args = dataclasses.asdict(source)
            local_kw_args = {k: local_kw_args[k] for k in source.assigned}
            kw_args = {**kw_args, **local_kw_args}

        obj = cls(**kw_args)

        return obj

    @classmethod
    def from_deprecated_args(cls, args):
        obj = cls()
        ignored = obj.configure(ignore_unrecognized=True, **args)

        return obj, ignored

    @classmethod
    def from_path(cls, name):
        with open(name) as f:
            args = ujson.load(f)

            if "config" in args:
                args = args["config"]

        return cls.from_deprecated_args(
            args
        )  # the new, non-deprecated version functions the same at this level.

    @classmethod
    def load_from_checkpoint(cls, checkpoint_path):
        if checkpoint_path.endswith(".dnn"):
            dnn = torch_load_dnn(checkpoint_path)
            config, _ = cls.from_deprecated_args(dnn.get("arguments", {}))

            # TODO: FIXME: Decide if the line below will have any unintended consequences. We don't want to overwrite those!
            config.set("checkpoint", checkpoint_path)

            return config

        try:
            checkpoint_path = hf_hub_download(
                repo_id=checkpoint_path, filename="artifact.metadata"
            ).split("artifact")[0]
        except Exception:
            pass
        loaded_config_path = os.path.join(checkpoint_path, "artifact.metadata")
        if os.path.exists(loaded_config_path):
            loaded_config, _ = cls.from_path(loaded_config_path)
            loaded_config.set("checkpoint", checkpoint_path)

            return loaded_config

        return (
            None  # can happen if checkpoint_path is something like 'bert-base-uncased'
        )

    @classmethod
    def load_from_index(cls, index_path):
        # FIXME: We should start here with initial_config = ColBERTConfig(config, Run().config).
        # This should allow us to say initial_config.index_root. Then, below, set config = Config(..., initial_c)

        # default_index_root = os.path.join(Run().root, Run().experiment, 'indexes/')
        # index_path = os.path.join(default_index_root, index_path)

        # CONSIDER: No more plan/metadata.json. Only metadata.json to avoid weird issues when loading an index.

        try:
            metadata_path = os.path.join(index_path, "metadata.json")
            loaded_config, _ = cls.from_path(metadata_path)
        except Exception:
            metadata_path = os.path.join(index_path, "plan.json")
            loaded_config, _ = cls.from_path(metadata_path)

        return loaded_config

    def save(self, path, overwrite=False):
        assert overwrite or not os.path.exists(path), path

        with open(path, "w") as f:
            args = self.export()  # dict(self.__config)
            args["meta"] = get_metadata_only()
            args["meta"]["version"] = "colbert-v0.4"
            # TODO: Add git_status details.. It can't be too large! It should be a path that Runs() saves on exit, maybe!

            f.write(ujson.dumps(args, indent=4) + "\n")

    def save_for_checkpoint(self, checkpoint_path):
        assert not checkpoint_path.endswith(
            ".dnn"
        ), f"{checkpoint_path}: We reserve *.dnn names for the deprecated checkpoint format."

        output_config_path = os.path.join(checkpoint_path, "artifact.metadata")
        self.save(output_config_path, overwrite=True)


@dataclass
class ColBERTConfig(RunSettings, ResourceSettings, DocSettings, QuerySettings, TrainingSettings,
                    IndexingSettings, SearchSettings, BaseConfig, TokenizerSettings):
    pass
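
Note: CoreConfig.__post_init__ records which fields were explicitly assigned, and BaseConfig.from_existing merges only those assigned fields, with later sources taking precedence. A hypothetical usage sketch (not part of the upload), assuming the repo is importable as the `constbert` package, as tokenization_utils.py below assumes:

from constbert.colbert_configuration import ColBERTConfig

base = ColBERTConfig(doc_maxlen=250, nbits=2)  # both fields land in .assigned
override = ColBERTConfig(nbits=1)              # only nbits lands in .assigned

# from_existing() copies only explicitly assigned fields; later sources win.
merged = ColBERTConfig.from_existing(base, override)
print(merged.doc_maxlen, merged.nbits)  # 250 1
print(merged.query_maxlen)              # 32 (unassigned fields keep their soft defaults)

This is why the soft-default machinery matters: a field left at its DefaultVal is invisible to the merge, so run-level and checkpoint-level configs can be layered without clobbering each other's defaults.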
tokenization_utils.py ADDED
@@ -0,0 +1,191 @@
import torch
from constbert.colbert_configuration import ColBERTConfig
from transformers import AutoTokenizer

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _split_into_batches(ids, mask, bsize):
    batches = []
    for offset in range(0, ids.size(0), bsize):
        batches.append((ids[offset:offset + bsize], mask[offset:offset + bsize]))

    return batches


def _sort_by_length(ids, mask, bsize):
    if ids.size(0) <= bsize:
        return ids, mask, torch.arange(ids.size(0))

    indices = mask.sum(-1).sort().indices
    reverse_indices = indices.sort().indices

    return ids[indices], mask[indices], reverse_indices


class QueryTokenizer():
    def __init__(self, config: ColBERTConfig, verbose: int = 3):
        self.tok = AutoTokenizer.from_pretrained(config.checkpoint)
        self.tok.base = config.checkpoint
        self.verbose = verbose

        self.config = config
        self.query_maxlen = config.query_maxlen
        self.background_maxlen = 512 - self.query_maxlen + 1  # FIXME: Make this configurable

        self.Q_marker_token, self.Q_marker_token_id = config.query_token, self.tok.convert_tokens_to_ids(config.query_token_id)
        self.cls_token, self.cls_token_id = self.tok.cls_token, self.tok.cls_token_id
        self.sep_token, self.sep_token_id = self.tok.sep_token, self.tok.sep_token_id
        self.mask_token, self.mask_token_id = self.tok.mask_token, self.tok.mask_token_id
        self.pad_token, self.pad_token_id = self.tok.pad_token, self.tok.pad_token_id
        self.used = False

    def tokenize(self, batch_text, add_special_tokens=False):
        assert type(batch_text) in [list, tuple], (type(batch_text))

        tokens = [self.tok.tokenize(x, add_special_tokens=False) for x in batch_text]

        if not add_special_tokens:
            return tokens

        prefix, suffix = [self.cls_token, self.Q_marker_token], [self.sep_token]
        tokens = [prefix + lst + suffix + [self.mask_token] * (self.query_maxlen - (len(lst) + 3)) for lst in tokens]

        return tokens

    def encode(self, batch_text, add_special_tokens=False):
        assert type(batch_text) in [list, tuple], (type(batch_text))

        # Without return_tensors, the tokenizer returns plain Python lists, so
        # the stray .to(DEVICE) in the uploaded version is dropped here.
        ids = self.tok(batch_text, add_special_tokens=False)['input_ids']

        if not add_special_tokens:
            return ids

        prefix, suffix = [self.cls_token_id, self.Q_marker_token_id], [self.sep_token_id]
        ids = [prefix + lst + suffix + [self.mask_token_id] * (self.query_maxlen - (len(lst) + 3)) for lst in ids]

        return ids

    def tensorize(self, batch_text, bsize=None, context=None, full_length_search=False):
        assert type(batch_text) in [list, tuple], (type(batch_text))

        # add placeholder for the [Q] marker
        batch_text = ['. ' + x for x in batch_text]

        # Full-length search is only available for single inference (for now).
        # Batched full-length search requires far deeper changes to the code base.
        assert full_length_search == False or (type(batch_text) == list and len(batch_text) == 1)

        if full_length_search:
            # Tokenize each string in the batch (list output; no device move needed)
            un_truncated_ids = self.tok(batch_text, add_special_tokens=False)['input_ids']
            # Get the longest length in the batch
            max_length_in_batch = max(len(x) for x in un_truncated_ids)
            # Set the max length
            max_length = self.max_len(max_length_in_batch)
        else:
            # Max length is the default max length from the config
            max_length = self.query_maxlen

        obj = self.tok(batch_text, padding='max_length', truncation=True,
                       return_tensors='pt', max_length=max_length).to(DEVICE)

        ids, mask = obj['input_ids'], obj['attention_mask']

        # postprocess for the [Q] marker and the [MASK] augmentation
        ids[:, 1] = self.Q_marker_token_id
        ids[ids == self.pad_token_id] = self.mask_token_id

        if context is not None:
            assert len(context) == len(batch_text), (len(context), len(batch_text))

            obj_2 = self.tok(context, padding='longest', truncation=True,
                             return_tensors='pt', max_length=self.background_maxlen).to(DEVICE)

            ids_2, mask_2 = obj_2['input_ids'][:, 1:], obj_2['attention_mask'][:, 1:]  # Skip the first [SEP]

            ids = torch.cat((ids, ids_2), dim=-1)
            mask = torch.cat((mask, mask_2), dim=-1)

        if self.config.attend_to_mask_tokens:
            mask[ids == self.mask_token_id] = 1
            assert mask.sum().item() == mask.size(0) * mask.size(1), mask

        if bsize:
            batches = _split_into_batches(ids, mask, bsize)
            return batches

        if self.used is False:
            self.used = True

            firstbg = (context is None) or context[0]
            if self.verbose > 1:
                print()
                print("#> QueryTokenizer.tensorize(batch_text[0], batch_background[0], bsize) ==")
                print(f"#> Input: {batch_text[0]}, \t\t {firstbg}, \t\t {bsize}")
                print(f"#> Output IDs: {ids[0].size()}, {ids[0]}")
                print(f"#> Output Mask: {mask[0].size()}, {mask[0]}")
                print()

        return ids, mask

    # Ensure that query_maxlen <= length <= 500 tokens
    def max_len(self, length):
        return min(500, max(self.query_maxlen, length))


class DocTokenizer():
    def __init__(self, config: ColBERTConfig):
        self.tok = AutoTokenizer.from_pretrained(config.checkpoint)
        self.tok.base = config.checkpoint

        self.config = config
        self.doc_maxlen = config.doc_maxlen

        self.D_marker_token, self.D_marker_token_id = self.config.doc_token, self.tok.convert_tokens_to_ids(self.config.doc_token_id)
        self.cls_token, self.cls_token_id = self.tok.cls_token, self.tok.cls_token_id
        self.sep_token, self.sep_token_id = self.tok.sep_token, self.tok.sep_token_id

    def tokenize(self, batch_text, add_special_tokens=False):
        assert type(batch_text) in [list, tuple], (type(batch_text))

        # tokenize() returns lists of string tokens; the .to(DEVICE) in the
        # uploaded version would crash, so it is dropped here.
        tokens = [self.tok.tokenize(x, add_special_tokens=False) for x in batch_text]

        if not add_special_tokens:
            return tokens

        prefix, suffix = [self.cls_token, self.D_marker_token], [self.sep_token]
        tokens = [prefix + lst + suffix for lst in tokens]

        return tokens

    def encode(self, batch_text, add_special_tokens=False):
        assert type(batch_text) in [list, tuple], (type(batch_text))

        # List output again; no device move needed.
        ids = self.tok(batch_text, add_special_tokens=False)['input_ids']

        if not add_special_tokens:
            return ids

        prefix, suffix = [self.cls_token_id, self.D_marker_token_id], [self.sep_token_id]
        ids = [prefix + lst + suffix for lst in ids]

        return ids

    def tensorize(self, batch_text, bsize=None):
        assert type(batch_text) in [list, tuple], (type(batch_text))

        # add placeholder for the [D] marker
        batch_text = ['. ' + x for x in batch_text]

        obj = self.tok(batch_text, padding='max_length', truncation='longest_first',
                       return_tensors='pt', max_length=self.doc_maxlen).to(DEVICE)

        ids, mask = obj['input_ids'], obj['attention_mask']

        # postprocess for the [D] marker
        ids[:, 1] = self.D_marker_token_id

        if bsize:
            ids, mask, reverse_indices = _sort_by_length(ids, mask, bsize)
            batches = _split_into_batches(ids, mask, bsize)
            return batches, reverse_indices

        return ids, mask
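
Note: a hypothetical end-to-end sketch (not part of the upload), assuming the same `constbert` package layout and the bert-base-uncased checkpoint recorded in artifact.metadata:

from constbert.colbert_configuration import ColBERTConfig
from constbert.tokenization_utils import QueryTokenizer, DocTokenizer

config = ColBERTConfig(checkpoint="bert-base-uncased", query_maxlen=32, doc_maxlen=250)

query_tokenizer = QueryTokenizer(config)
doc_tokenizer = DocTokenizer(config)

# Queries are padded to query_maxlen and the padding is rewritten to [MASK]
# tokens (ColBERT's query augmentation); position 1 carries the [Q] marker.
q_ids, q_mask = query_tokenizer.tensorize(["what is colbert?"])

# Documents are padded/truncated to doc_maxlen; position 1 carries the [D] marker.
d_ids, d_mask = doc_tokenizer.tensorize(["ColBERT is a late-interaction retriever."])

print(q_ids.shape, d_ids.shape)  # torch.Size([1, 32]) torch.Size([1, 250])

Passing bsize to either tensorize() call instead returns (ids, mask) batches, with DocTokenizer additionally sorting by length and returning reverse_indices to restore the original order.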