# ConstBERT / modeling.py
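# ConstBERT modeling code: a ColBERT-style late-interaction bi-encoder in which every
# document is pooled down to a constant number of embeddings by a learned linear
# projection over token positions (see `doc_project` in ConstBERT.__init__).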
import os

import torch
import torch.nn as nn
from tqdm import tqdm
from transformers import BertPreTrainedModel, BertModel, AutoTokenizer
from huggingface_hub import hf_hub_download

from .colbert_configuration import ColBERTConfig
from .tokenization_utils import QueryTokenizer, DocTokenizer
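# No-op context manager used when mixed precision is disabled (same idea as
# contextlib.nullcontext()).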
class NullContextManager(object):
def __init__(self, dummy_resource=None):
self.dummy_resource = dummy_resource
def __enter__(self):
return self.dummy_resource
def __exit__(self, *args):
pass
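# Thin wrapper around torch.amp: autocast for the forward pass, GradScaler for the
# backward pass and optimizer step, with gradient clipping at max-norm 2.0.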
class MixedPrecisionManager():
def __init__(self, activated):
self.activated = activated
if self.activated:
self.scaler = torch.amp.GradScaler("cuda")
def context(self):
return torch.amp.autocast("cuda") if self.activated else NullContextManager()
def backward(self, loss):
if self.activated:
self.scaler.scale(loss).backward()
else:
loss.backward()
def step(self, colbert, optimizer, scheduler=None):
if self.activated:
self.scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(colbert.parameters(), 2.0, error_if_nonfinite=False)
self.scaler.step(optimizer)
self.scaler.update()
else:
torch.nn.utils.clip_grad_norm_(colbert.parameters(), 2.0)
optimizer.step()
if scheduler is not None:
scheduler.step()
optimizer.zero_grad()
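# ConstBERT keeps ColBERT's per-token query embeddings but compresses each document into a
# constant number of vectors: `doc_project` maps the doc_maxlen token positions to 32 pooled
# positions (a query-side projection to 64 positions exists but is unused in `_query`).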
class ConstBERT(BertPreTrainedModel):
"""
Shallow wrapper around HuggingFace transformers. All new parameters should be defined at this level.
This makes sure `{from,save}_pretrained` and `init_weights` are applied to new parameters correctly.
"""
_keys_to_ignore_on_load_unexpected = [r"cls"]
def __init__(self, config, colbert_config, verbose:int = 0):
super().__init__(config)
        self.config = config
        self.colbert_config = colbert_config  # kept on the instance (used e.g. by encode_documents' flatten path)
        self.dim = colbert_config.dim
        self.linear = nn.Linear(config.hidden_size, colbert_config.dim, bias=False)
        # Learned pooling over token positions: every document is reduced to 32 vectors
        # (and, if the query-side projection were enabled, every query to 64).
        self.doc_project = nn.Linear(colbert_config.doc_maxlen, 32, bias=False)
        self.query_project = nn.Linear(colbert_config.query_maxlen, 64, bias=False)
        # Download the required tokenizer files from the Hub if they are not already present locally.
        for filename in ("tokenizer.json", "vocab.txt", "tokenizer_config.json", "special_tokens_map.json"):
            if not os.path.exists(os.path.join(colbert_config.name_or_path, filename)):
                hf_hub_download(repo_id=colbert_config.name_or_path, filename=filename)
self.query_tokenizer = QueryTokenizer(colbert_config, verbose=verbose)
self.doc_tokenizer = DocTokenizer(colbert_config)
self.amp_manager = MixedPrecisionManager(True)
self.raw_tokenizer = AutoTokenizer.from_pretrained(colbert_config.checkpoint)
self.pad_token = self.raw_tokenizer.pad_token_id
self.use_gpu = colbert_config.total_visible_gpus > 0
        # Register the underlying encoder under the name HF expects (i.e. `self.bert`).
        setattr(self, self.base_model_prefix, BertModel(config))
# if colbert_config.relu:
# self.score_scaler = nn.Linear(1, 1)
self.init_weights()
# if colbert_config.relu:
# self.score_scaler.weight.data.fill_(1.0)
# self.score_scaler.bias.data.fill_(-8.0)
@property
def LM(self):
base_model_prefix = getattr(self, "base_model_prefix")
return getattr(self, base_model_prefix)
@classmethod
def from_pretrained(cls, name_or_path, config=None, *args, **kwargs):
colbert_config = ColBERTConfig(name_or_path)
colbert_config = ColBERTConfig.from_existing(ColBERTConfig.load_from_checkpoint(name_or_path), colbert_config)
obj = super().from_pretrained(name_or_path, colbert_config=colbert_config, config=config)
obj.base = name_or_path
return obj
@staticmethod
def raw_tokenizer_from_pretrained(name_or_path):
obj = AutoTokenizer.from_pretrained(name_or_path)
obj.base = name_or_path
return obj
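    # Encode queries: BERT token embeddings -> linear projection to `dim` -> zero out
    # padding positions -> L2-normalise each token vector.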
def _query(self, input_ids, attention_mask):
input_ids, attention_mask = input_ids.to(self.device), attention_mask.to(self.device)
Q = self.bert(input_ids, attention_mask=attention_mask)[0]
# Q = Q.permute(0, 2, 1) #(64, 128,32)
# Q = self.query_project(Q) #(64, 128,8)
# Q = Q.permute(0, 2, 1) #(64,8,128)
Q = self.linear(Q)
# mask = torch.ones(Q.shape[0], Q.shape[1], device=self.device).unsqueeze(2).float()
mask = torch.tensor(self.mask(input_ids, skiplist=[]), device=self.device).unsqueeze(2).float()
Q = Q * mask
return torch.nn.functional.normalize(Q, p=2, dim=2)
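    # Encode documents: BERT token embeddings are pooled to 32 fixed positions via
    # `doc_project`, projected to `dim`, and L2-normalised (half precision on GPU).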
def _doc(self, input_ids, attention_mask, keep_dims=True):
assert keep_dims in [True, False, 'return_mask']
input_ids, attention_mask = input_ids.to(self.device), attention_mask.to(self.device)
D = self.bert(input_ids, attention_mask=attention_mask)[0]
        D = D.permute(0, 2, 1)   # (batch, hidden_size, doc_maxlen)
        D = self.doc_project(D)  # (batch, hidden_size, 32): pool token positions down to a constant 32
        D = D.permute(0, 2, 1)   # (batch, 32, hidden_size)
D = self.linear(D)
mask = torch.ones(D.shape[0], D.shape[1], device=self.device).unsqueeze(2).float()
# mask = torch.tensor(self.mask(input_ids, skiplist=self.skiplist), device=self.device).unsqueeze(2).float()
D = D * mask
D = torch.nn.functional.normalize(D, p=2, dim=2)
if self.use_gpu:
D = D.half()
if keep_dims is False:
D, mask = D.cpu(), mask.bool().cpu().squeeze(-1)
D = [d[mask[idx]] for idx, d in enumerate(D)]
elif keep_dims == 'return_mask':
return D, mask.bool()
return D
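    # Boolean mask marking tokens that are neither padding nor in the skiplist.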
def mask(self, input_ids, skiplist):
mask = [[(x not in skiplist) and (x != self.pad_token) for x in d] for d in input_ids.cpu().tolist()]
return mask
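    # Inference-time wrappers: run _query/_doc under no_grad and the AMP autocast context.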
def query(self, *args, to_cpu=False, **kw_args):
with torch.no_grad():
with self.amp_manager.context():
Q = self._query(*args, **kw_args)
return Q.cpu() if to_cpu else Q
def doc(self, *args, to_cpu=False, **kw_args):
with torch.no_grad():
with self.amp_manager.context():
D = self._doc(*args, **kw_args)
if to_cpu:
return (D[0].cpu(), *D[1:]) if isinstance(D, tuple) else D.cpu()
return D
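    # Encode one query string or a list of queries, optionally in batches of size `bsize`.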
def encode_queries(self, queries, bsize=None, to_cpu=False, context=None, full_length_search=False):
        if isinstance(queries, str):
            queries = [queries]
if bsize:
batches = self.query_tokenizer.tensorize(queries, context=context, bsize=bsize, full_length_search=full_length_search)
batches = [self.query(input_ids, attention_mask, to_cpu=to_cpu) for input_ids, attention_mask in batches]
return torch.cat(batches)
input_ids, attention_mask = self.query_tokenizer.tensorize(queries, context=context, full_length_search=full_length_search)
return self.query(input_ids, attention_mask)
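    # Encode documents. keep_dims=True returns a stacked 3-D tensor; 'flatten' returns a
    # single (total_vectors, dim) tensor plus per-document vector counts; False returns a
    # list of per-document tensors.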
def encode_documents(self, docs, bsize=None, keep_dims=True, to_cpu=False, showprogress=False, return_tokens=False):
        if isinstance(docs, str):
            docs = [docs]
assert keep_dims in [True, False, 'flatten']
if bsize:
text_batches, reverse_indices = self.doc_tokenizer.tensorize(docs, bsize=bsize)
returned_text = []
if return_tokens:
returned_text = [text for batch in text_batches for text in batch[0]]
returned_text = [returned_text[idx] for idx in reverse_indices.tolist()]
returned_text = [returned_text]
keep_dims_ = 'return_mask' if keep_dims == 'flatten' else keep_dims
batches = [self.doc(input_ids, attention_mask, keep_dims=keep_dims_, to_cpu=to_cpu)
for input_ids, attention_mask in tqdm(text_batches, disable=not showprogress)]
if keep_dims is True:
D = _stack_3D_tensors(batches)
return (D[reverse_indices], *returned_text)
elif keep_dims == 'flatten':
D, mask = [], []
for D_, mask_ in batches:
D.append(D_)
mask.append(mask_)
D, mask = torch.cat(D)[reverse_indices], torch.cat(mask)[reverse_indices]
doclens = mask.squeeze(-1).sum(-1).tolist()
D = D.view(-1, self.colbert_config.dim)
D = D[mask.bool().flatten()].cpu()
return (D, doclens, *returned_text)
assert keep_dims is False
D = [d for batch in batches for d in batch]
return ([D[idx] for idx in reverse_indices.tolist()], *returned_text)
input_ids, attention_mask = self.doc_tokenizer.tensorize(docs)
return self.doc(input_ids, attention_mask, keep_dims=keep_dims, to_cpu=to_cpu)
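# Pad and stack a list of (batch_i, len_i, dim) tensors into one (sum(batch_i), max(len_i), dim) tensor.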
def _stack_3D_tensors(groups):
bsize = sum([x.size(0) for x in groups])
maxlen = max([x.size(1) for x in groups])
hdim = groups[0].size(2)
output = torch.zeros(bsize, maxlen, hdim, device=groups[0].device, dtype=groups[0].dtype)
offset = 0
for x in groups:
endpos = offset + x.size(0)
output[offset:endpos, :x.size(1)] = x
offset = endpos
return output
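

if __name__ == "__main__":
    # Minimal usage sketch, not part of the model code. It assumes a ConstBERT checkpoint
    # is available; "path/to/ConstBERT-checkpoint" below is a placeholder, not a real
    # repo id. Queries get one embedding per token, documents get 32 pooled embeddings,
    # and relevance is scored with ColBERT-style MaxSim.
    model = ConstBERT.from_pretrained("path/to/ConstBERT-checkpoint")
    model.eval()

    Q = model.encode_queries("what is late interaction retrieval?")          # (1, query_maxlen, dim)
    D = model.encode_documents(["First passage ...", "Second passage ..."])  # (2, 32, dim)

    # MaxSim: each query token is matched to its most similar document vector; the
    # per-token maxima are summed into one relevance score per document.
    scores = (Q.float() @ D.float().transpose(1, 2)).max(dim=-1).values.sum(dim=-1)
    print(scores.tolist())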