import torch
from tqdm import tqdm
from torch.nn.functional import normalize
from transformers import BertConfig, BertModel, BertTokenizer


class TextEncoder(torch.nn.Module):
def __init__(self,
config_path: str,
out_dim: int,
load_pretrained: bool = True,
gradient_checkpointing: bool = False):
"""
Args:
config_path: Path to the config file
out_dim: Output dimension of the text representation
load_pretrained: Whether to load pretrained weights
gradient_checkpointing: Whether to enable gradient checkpointing
"""
super().__init__()
config = BertConfig.from_pretrained(config_path)
if load_pretrained:
self.model = BertModel.from_pretrained(config_path, add_pooling_layer=False)
else:
self.model = BertModel(config, add_pooling_layer=False)
self.out = torch.nn.Linear(config.hidden_size, out_dim)
        # Enable/disable gradient checkpointing by setting the flag that
        # BertEncoder.forward checks during training
        self.model.encoder.gradient_checkpointing = gradient_checkpointing
        self.tokenizer = BertTokenizer.from_pretrained(config_path)

def get_repr(self, texts: list, batch_size: int = 64, verbose: bool = False) -> torch.Tensor:
"""
Compute text representation for the given texts
Args:
texts: A list of strings
batch_size: Batch size for inference
verbose: Whether to print progress
"""
device = next(self.parameters()).device
text_repr = []
if verbose:
iterator = tqdm(range(0, len(texts), batch_size), desc="Computing text embeddings")
else:
iterator = range(0, len(texts), batch_size)
for i in iterator:
            # Tokenize the current batch (pad to the longest sequence in the
            # batch, truncate to 512 tokens)
            text_inputs = self.tokenizer(texts[i: i + batch_size],
                                         return_tensors="pt",
                                         truncation=True,
                                         max_length=512,
                                         padding=True)
            # Move the batch to the model's device and encode it
            text_inputs = {k: v.to(device) for k, v in text_inputs.items()}
            output = self(text_inputs)
            text_repr.append(output)
text_repr = torch.cat(text_repr, dim=0)
        return normalize(text_repr, dim=-1)

def forward(self, inputs: dict):
"""
Encode text into text representation
Args:
inputs: A dictionary containing the following keys:
- input_ids: [batch, seq_len]
- attention_mask: [batch, seq_len]
- token_type_ids: [batch, seq_len]
Returns:
text_repr: [batch, text_repr_dim]
"""
        # Use the [CLS] token (position 0) as the sentence-level representation
        reprs = self.model(**inputs).last_hidden_state[:, 0, :]
        reprs = self.out(reprs)
return reprs
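

if __name__ == "__main__":
    # Minimal usage sketch (illustrative, not part of the original module).
    # "path/to/bert_checkpoint" is a placeholder for a local BERT directory
    # containing the config, weights, and vocab files; out_dim=512 is an
    # assumed value.
    encoder = TextEncoder(config_path="path/to/bert_checkpoint", out_dim=512)
    encoder.eval()

    # get_repr retains gradients by default, so wrap pure inference in no_grad
    with torch.no_grad():
        reprs = encoder.get_repr(["protein kinase activity", "DNA binding"],
                                 batch_size=2, verbose=True)
    print(reprs.shape)  # torch.Size([2, 512]); rows are L2-normalized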