# ConstBERT / modeling.py
import torch
import torch.nn as nn

from tqdm import tqdm
from transformers import BertPreTrainedModel, BertModel, AutoTokenizer

from constbert.colbert_configuration import ColBERTConfig
from constbert.tokenization_utils import QueryTokenizer, DocTokenizer


class NullContextManager(object):
    def __init__(self, dummy_resource=None):
        self.dummy_resource = dummy_resource

    def __enter__(self):
        return self.dummy_resource

    def __exit__(self, *args):
        pass


class MixedPrecisionManager():
    """Thin wrapper around torch.cuda.amp (autocast + GradScaler); falls back to plain FP32 training when deactivated."""

    def __init__(self, activated):
        self.activated = activated

        if self.activated:
            self.scaler = torch.cuda.amp.GradScaler()

    def context(self):
        return torch.cuda.amp.autocast() if self.activated else NullContextManager()

    def backward(self, loss):
        if self.activated:
            self.scaler.scale(loss).backward()
        else:
            loss.backward()

    def step(self, colbert, optimizer, scheduler=None):
        if self.activated:
            self.scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(colbert.parameters(), 2.0, error_if_nonfinite=False)

            self.scaler.step(optimizer)
            self.scaler.update()
        else:
            torch.nn.utils.clip_grad_norm_(colbert.parameters(), 2.0)
            optimizer.step()

        if scheduler is not None:
            scheduler.step()

        optimizer.zero_grad()
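

# Illustrative training-loop usage of MixedPrecisionManager (commented out; `model`, `optimizer`,
# and `train_loader` are placeholders, not objects defined in this module):
#
#     amp = MixedPrecisionManager(activated=True)
#     for batch in train_loader:
#         with amp.context():
#             loss = model(batch)
#         amp.backward(loss)
#         amp.step(model, optimizer)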


class ConstBERT(BertPreTrainedModel):
    """
    Shallow wrapper around HuggingFace transformers. All new parameters should be defined at this level.

    This makes sure `{from,save}_pretrained` and `init_weights` are applied to new parameters correctly.
    """

    _keys_to_ignore_on_load_unexpected = [r"cls"]

    def __init__(self, config, colbert_config, verbose: int = 3):
        super().__init__(config)

        self.config = config
        self.colbert_config = colbert_config  # kept for later use (e.g. docFromText's 'flatten' path)
        self.dim = colbert_config.dim
        self.linear = nn.Linear(config.hidden_size, colbert_config.dim, bias=False)

        # Fixed-size projections over the sequence (length) dimension: documents are compressed
        # to 32 vectors each; the 64-vector query projection is currently unused in _query.
        self.doc_project = nn.Linear(colbert_config.doc_maxlen, 32, bias=False)
        self.query_project = nn.Linear(colbert_config.query_maxlen, 64, bias=False)

        self.query_tokenizer = QueryTokenizer(colbert_config, verbose=verbose)
        self.doc_tokenizer = DocTokenizer(colbert_config)
        self.amp_manager = MixedPrecisionManager(True)
        self.raw_tokenizer = AutoTokenizer.from_pretrained(colbert_config.checkpoint)
        self.pad_token = self.raw_tokenizer.pad_token_id
        self.use_gpu = colbert_config.total_visible_gpus > 0

        setattr(self, self.base_model_prefix, BertModel(config))

        # if colbert_config.relu:
        #     self.score_scaler = nn.Linear(1, 1)

        self.init_weights()

        # if colbert_config.relu:
        #     self.score_scaler.weight.data.fill_(1.0)
        #     self.score_scaler.bias.data.fill_(-8.0)

    @property
    def LM(self):
        base_model_prefix = getattr(self, "base_model_prefix")
        return getattr(self, base_model_prefix)

    @classmethod
    def from_pretrained(cls, name_or_path):
        colbert_config = ColBERTConfig(name_or_path)
        colbert_config = ColBERTConfig.from_existing(ColBERTConfig.load_from_checkpoint(name_or_path), colbert_config)

        obj = super().from_pretrained(name_or_path, colbert_config=colbert_config)
        obj.base = name_or_path

        return obj

    @staticmethod
    def raw_tokenizer_from_pretrained(name_or_path):
        obj = AutoTokenizer.from_pretrained(name_or_path)
        obj.base = name_or_path

        return obj
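
    # Illustrative loading sketch (not executed here): the path below is a placeholder for a
    # local directory or Hub repo that holds both the HF weights and the ColBERT config
    # metadata that ColBERTConfig.load_from_checkpoint expects.
    #
    #     model = ConstBERT.from_pretrained("path/to/constbert-checkpoint")
    #     tokenizer = ConstBERT.raw_tokenizer_from_pretrained("path/to/constbert-checkpoint")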

    def _query(self, input_ids, attention_mask):
        input_ids, attention_mask = input_ids.to(self.device), attention_mask.to(self.device)
        Q = self.bert(input_ids, attention_mask=attention_mask)[0]
        # Q = Q.permute(0, 2, 1)
        # Q = self.query_project(Q)
        # Q = Q.permute(0, 2, 1)
        Q = self.linear(Q)

        # mask = torch.ones(Q.shape[0], Q.shape[1], device=self.device).unsqueeze(2).float()
        mask = torch.tensor(self.mask(input_ids, skiplist=[]), device=self.device).unsqueeze(2).float()
        Q = Q * mask

        return torch.nn.functional.normalize(Q, p=2, dim=2)

    def _doc(self, input_ids, attention_mask, keep_dims=True):
        assert keep_dims in [True, False, 'return_mask']

        input_ids, attention_mask = input_ids.to(self.device), attention_mask.to(self.device)
        D = self.bert(input_ids, attention_mask=attention_mask)[0]

        # Project the sequence dimension down to a constant number of vectors per document:
        # (batch, doc_maxlen, hidden) -> (batch, hidden, doc_maxlen) -> (batch, hidden, 32) -> (batch, 32, hidden).
        D = D.permute(0, 2, 1)
        D = self.doc_project(D)
        D = D.permute(0, 2, 1)
        D = self.linear(D)

        mask = torch.ones(D.shape[0], D.shape[1], device=self.device).unsqueeze(2).float()
        # mask = torch.tensor(self.mask(input_ids, skiplist=self.skiplist), device=self.device).unsqueeze(2).float()
        D = D * mask

        D = torch.nn.functional.normalize(D, p=2, dim=2)
        if self.use_gpu:
            D = D.half()

        if keep_dims is False:
            D, mask = D.cpu(), mask.bool().cpu().squeeze(-1)
            D = [d[mask[idx]] for idx, d in enumerate(D)]

        elif keep_dims == 'return_mask':
            return D, mask.bool()

        return D

    def mask(self, input_ids, skiplist):
        mask = [[(x not in skiplist) and (x != self.pad_token) for x in d] for d in input_ids.cpu().tolist()]
        return mask
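
    # Example (illustrative): with a BERT-style vocabulary where [CLS]=101, [SEP]=102 and [PAD]=0,
    # mask(torch.tensor([[101, 2054, 102, 0, 0]]), skiplist=[]) -> [[True, True, True, False, False]]:
    # every non-pad, non-skiplist position is kept.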

    def query(self, *args, to_cpu=False, **kw_args):
        with torch.no_grad():
            with self.amp_manager.context():
                Q = self._query(*args, **kw_args)
                return Q.cpu() if to_cpu else Q

    def doc(self, *args, to_cpu=False, **kw_args):
        with torch.no_grad():
            with self.amp_manager.context():
                D = self._doc(*args, **kw_args)

                if to_cpu:
                    return (D[0].cpu(), *D[1:]) if isinstance(D, tuple) else D.cpu()

                return D

    def queryFromText(self, queries, bsize=None, to_cpu=False, context=None, full_length_search=False):
        if bsize:
            batches = self.query_tokenizer.tensorize(queries, context=context, bsize=bsize, full_length_search=full_length_search)
            batches = [self.query(input_ids, attention_mask, to_cpu=to_cpu) for input_ids, attention_mask in batches]
            return torch.cat(batches)

        input_ids, attention_mask = self.query_tokenizer.tensorize(queries, context=context, full_length_search=full_length_search)
        return self.query(input_ids, attention_mask)

    def docFromText(self, docs, bsize=None, keep_dims=True, to_cpu=False, showprogress=False, return_tokens=False):
        # keep_dims=True returns a padded (ndocs, nvecs, dim) tensor; 'flatten' returns a single
        # (total_vecs, dim) tensor plus per-document lengths; False returns a list of per-document tensors.
        assert keep_dims in [True, False, 'flatten']

        if bsize:
            text_batches, reverse_indices = self.doc_tokenizer.tensorize(docs, bsize=bsize)

            returned_text = []
            if return_tokens:
                returned_text = [text for batch in text_batches for text in batch[0]]
                returned_text = [returned_text[idx] for idx in reverse_indices.tolist()]
                returned_text = [returned_text]

            keep_dims_ = 'return_mask' if keep_dims == 'flatten' else keep_dims
            batches = [self.doc(input_ids, attention_mask, keep_dims=keep_dims_, to_cpu=to_cpu)
                       for input_ids, attention_mask in tqdm(text_batches, disable=not showprogress)]

            if keep_dims is True:
                D = _stack_3D_tensors(batches)
                return (D[reverse_indices], *returned_text)

            elif keep_dims == 'flatten':
                D, mask = [], []

                for D_, mask_ in batches:
                    D.append(D_)
                    mask.append(mask_)

                D, mask = torch.cat(D)[reverse_indices], torch.cat(mask)[reverse_indices]

                doclens = mask.squeeze(-1).sum(-1).tolist()

                D = D.view(-1, self.colbert_config.dim)
                D = D[mask.bool().flatten()].cpu()

                return (D, doclens, *returned_text)

            assert keep_dims is False

            D = [d for batch in batches for d in batch]
            return ([D[idx] for idx in reverse_indices.tolist()], *returned_text)

        input_ids, attention_mask = self.doc_tokenizer.tensorize(docs)
        return self.doc(input_ids, attention_mask, keep_dims=keep_dims, to_cpu=to_cpu)


def _stack_3D_tensors(groups):
    bsize = sum([x.size(0) for x in groups])
    maxlen = max([x.size(1) for x in groups])
    hdim = groups[0].size(2)

    output = torch.zeros(bsize, maxlen, hdim, device=groups[0].device, dtype=groups[0].dtype)

    offset = 0
    for x in groups:
        endpos = offset + x.size(0)
        output[offset:endpos, :x.size(1)] = x
        offset = endpos

    return output
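

if __name__ == "__main__":
    # Minimal encoding sketch, assuming a ConstBERT checkpoint is available locally or on the Hub
    # (the path below is a placeholder). With the projections above, every document is encoded
    # into 32 vectors of size colbert_config.dim, while queries keep one vector per (padded) token.
    ckpt = "path/to/constbert-checkpoint"
    model = ConstBERT.from_pretrained(ckpt).eval()

    Q = model.queryFromText(["what is late interaction retrieval?"])
    D = model.docFromText(["ColBERT-style models score query and document token embeddings pairwise."],
                          keep_dims=True)

    print(Q.shape, D.shape)  # (num_queries, query_maxlen, dim) and (num_docs, 32, dim)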