import torch

from transformers import AutoTokenizer

from .colbert_configuration import ColBERTConfig

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def _split_into_batches(ids, mask, bsize):
    batches = []
    for offset in range(0, ids.size(0), bsize):
        batches.append((ids[offset:offset+bsize], mask[offset:offset+bsize]))

    return batches
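
# Illustrative sketch (not part of the original module): for `ids` and `mask`
# tensors with 5 rows and bsize=2, _split_into_batches yields three
# (ids, mask) pairs of 2, 2, and 1 rows, preserving the input row order.
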
def _sort_by_length(ids, mask, bsize):
    if ids.size(0) <= bsize:
        return ids, mask, torch.arange(ids.size(0))

    indices = mask.sum(-1).sort().indices
    reverse_indices = indices.sort().indices

    return ids[indices], mask[indices], reverse_indices
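
# Illustrative note (not part of the original module): _sort_by_length orders
# rows by the number of attended tokens (mask.sum(-1)) so that batches built
# afterwards group sequences of similar length; `reverse_indices` is the
# inverse permutation, i.e. sorted_ids[reverse_indices] recovers the original
# row order.
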
class QueryTokenizer():
    def __init__(self, config: ColBERTConfig, verbose: int = 3):
        self.tok = AutoTokenizer.from_pretrained(config.checkpoint)
        self.tok.base = config.checkpoint
        self.verbose = verbose

        self.config = config
        self.query_maxlen = config.query_maxlen
        self.background_maxlen = 512 - self.query_maxlen + 1  # FIXME: Make this configurable

        self.Q_marker_token, self.Q_marker_token_id = config.query_token, self.tok.convert_tokens_to_ids(config.query_token_id)
        self.cls_token, self.cls_token_id = self.tok.cls_token, self.tok.cls_token_id
        self.sep_token, self.sep_token_id = self.tok.sep_token, self.tok.sep_token_id
        self.mask_token, self.mask_token_id = self.tok.mask_token, self.tok.mask_token_id
        self.pad_token, self.pad_token_id = self.tok.pad_token, self.tok.pad_token_id

        self.used = False
    def tokenize(self, batch_text, add_special_tokens=False):
        assert type(batch_text) in [list, tuple], (type(batch_text))

        tokens = [self.tok.tokenize(x, add_special_tokens=False) for x in batch_text]

        if not add_special_tokens:
            return tokens

        prefix, suffix = [self.cls_token, self.Q_marker_token], [self.sep_token]
        tokens = [prefix + lst + suffix + [self.mask_token] * (self.query_maxlen - (len(lst) + 3)) for lst in tokens]

        return tokens
    def encode(self, batch_text, add_special_tokens=False):
        assert type(batch_text) in [list, tuple], (type(batch_text))

        ids = self.tok(batch_text, add_special_tokens=False).to(DEVICE)['input_ids']

        if not add_special_tokens:
            return ids

        prefix, suffix = [self.cls_token_id, self.Q_marker_token_id], [self.sep_token_id]
        ids = [prefix + lst + suffix + [self.mask_token_id] * (self.query_maxlen - (len(lst) + 3)) for lst in ids]

        return ids
    def tensorize(self, batch_text, bsize=None, context=None, full_length_search=False):
        assert type(batch_text) in [list, tuple], (type(batch_text))

        # add placeholder for the [Q] marker
        batch_text = ['. ' + x for x in batch_text]

        # Full-length search is only available for single inference (for now).
        # Batched full-length search requires far deeper changes to the code base.
        assert(full_length_search == False or (type(batch_text) == list and len(batch_text) == 1))

        if full_length_search:
            # Tokenize each string in the batch
            un_truncated_ids = self.tok(batch_text, add_special_tokens=False).to(DEVICE)['input_ids']
            # Get the longest length in the batch
            max_length_in_batch = max(len(x) for x in un_truncated_ids)
            # Set the max length
            max_length = self.max_len(max_length_in_batch)
        else:
            # Max length is the default max length from the config
            max_length = self.query_maxlen

        obj = self.tok(batch_text, padding='max_length', truncation=True,
                       return_tensors='pt', max_length=max_length).to(DEVICE)

        ids, mask = obj['input_ids'], obj['attention_mask']

        # postprocess for the [Q] marker and the [MASK] augmentation
        ids[:, 1] = self.Q_marker_token_id
        ids[ids == self.pad_token_id] = self.mask_token_id

        if context is not None:
            assert len(context) == len(batch_text), (len(context), len(batch_text))

            obj_2 = self.tok(context, padding='longest', truncation=True,
                             return_tensors='pt', max_length=self.background_maxlen).to(DEVICE)

            ids_2, mask_2 = obj_2['input_ids'][:, 1:], obj_2['attention_mask'][:, 1:]  # Skip the first [SEP]

            ids = torch.cat((ids, ids_2), dim=-1)
            mask = torch.cat((mask, mask_2), dim=-1)

        if self.config.attend_to_mask_tokens:
            mask[ids == self.mask_token_id] = 1
            assert mask.sum().item() == mask.size(0) * mask.size(1), mask

        if bsize:
            batches = _split_into_batches(ids, mask, bsize)
            return batches

        if self.used is False:
            self.used = True

            firstbg = (context is None) or context[0]
            if self.verbose > 1:
                print()
                print("#> QueryTokenizer.tensorize(batch_text[0], batch_background[0], bsize) ==")
                print(f"#> Input: {batch_text[0]}, \t\t {firstbg}, \t\t {bsize}")
                print(f"#> Output IDs: {ids[0].size()}, {ids[0]}")
                print(f"#> Output Mask: {mask[0].size()}, {mask[0]}")
                print()

        return ids, mask
    # Ensure that query_maxlen <= length <= 500 tokens
    def max_len(self, length):
        return min(500, max(self.query_maxlen, length))
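
# Illustrative usage sketch (assumptions, not from the original module): the
# checkpoint name and query text below are placeholders; any ColBERTConfig with
# compatible `checkpoint`, `query_token`, `query_token_id`, and `query_maxlen`
# settings would behave the same way.
#
#   config = ColBERTConfig(checkpoint='bert-base-uncased', query_maxlen=32)
#   query_tokenizer = QueryTokenizer(config)
#   ids, mask = query_tokenizer.tensorize(['what is late interaction?'])
#   # ids and mask have shape (1, 32); position 1 holds the [Q] marker id and
#   # padding positions are replaced by [MASK] ids (query augmentation).
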
class DocTokenizer():
    def __init__(self, config: ColBERTConfig):
        self.tok = AutoTokenizer.from_pretrained(config.checkpoint)
        self.tok.base = config.checkpoint

        self.config = config
        self.doc_maxlen = config.doc_maxlen

        self.D_marker_token, self.D_marker_token_id = self.config.doc_token, self.tok.convert_tokens_to_ids(self.config.doc_token_id)
        self.cls_token, self.cls_token_id = self.tok.cls_token, self.tok.cls_token_id
        self.sep_token, self.sep_token_id = self.tok.sep_token, self.tok.sep_token_id
    def tokenize(self, batch_text, add_special_tokens=False):
        assert type(batch_text) in [list, tuple], (type(batch_text))

        # tok.tokenize() returns plain Python lists of string tokens,
        # so there is no tensor to move to a device here.
        tokens = [self.tok.tokenize(x, add_special_tokens=False) for x in batch_text]

        if not add_special_tokens:
            return tokens

        prefix, suffix = [self.cls_token, self.D_marker_token], [self.sep_token]
        tokens = [prefix + lst + suffix for lst in tokens]

        return tokens
    def encode(self, batch_text, add_special_tokens=False):
        assert type(batch_text) in [list, tuple], (type(batch_text))

        ids = self.tok(batch_text, add_special_tokens=False).to(DEVICE)['input_ids']

        if not add_special_tokens:
            return ids

        prefix, suffix = [self.cls_token_id, self.D_marker_token_id], [self.sep_token_id]
        ids = [prefix + lst + suffix for lst in ids]

        return ids
    def tensorize(self, batch_text, bsize=None):
        assert type(batch_text) in [list, tuple], (type(batch_text))

        # add placeholder for the [D] marker
        batch_text = ['. ' + x for x in batch_text]

        obj = self.tok(batch_text, padding='max_length', truncation='longest_first',
                       return_tensors='pt', max_length=self.doc_maxlen).to(DEVICE)

        ids, mask = obj['input_ids'], obj['attention_mask']

        # postprocess for the [D] marker
        ids[:, 1] = self.D_marker_token_id

        if bsize:
            ids, mask, reverse_indices = _sort_by_length(ids, mask, bsize)
            batches = _split_into_batches(ids, mask, bsize)
            return batches, reverse_indices

        return ids, mask
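
# Illustrative usage sketch (assumptions, not from the original module): with a
# config like the hypothetical one above, DocTokenizer pads/truncates passages
# to `doc_maxlen` and, when `bsize` is given, returns length-sorted batches
# together with the indices needed to restore the original passage order.
#
#   doc_tokenizer = DocTokenizer(config)
#   batches, reverse_indices = doc_tokenizer.tensorize(passages, bsize=32)
#   # each element of `batches` is an (ids, mask) pair with at most 32 rows;
#   # per-batch outputs can be re-ordered afterwards using `reverse_indices`.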