import torch

from tqdm import tqdm
from torch.nn.functional import normalize
from transformers import BertConfig, BertModel, BertTokenizer


class TextEncoder(torch.nn.Module):
    def __init__(self,
                 config_path: str,
                 out_dim: int,
                 load_pretrained: bool = True,
                 gradient_checkpointing: bool = False):
        """
        Args:
            config_path: Path to the config file
            out_dim: Output dimension of the text representation
            load_pretrained: Whether to load pretrained weights
            gradient_checkpointing: Whether to enable gradient checkpointing
        """
        super().__init__()
        config = BertConfig.from_pretrained(config_path)
        if load_pretrained:
            self.model = BertModel.from_pretrained(config_path, add_pooling_layer=False)
        else:
            self.model = BertModel(config, add_pooling_layer=False)
        self.out = torch.nn.Linear(config.hidden_size, out_dim)
        
        # Enable/disable gradient checkpointing on the underlying BERT encoder
        # to trade extra compute for lower activation memory during training.
        # (Newer transformers versions also expose model.gradient_checkpointing_enable().)
        self.model.encoder.gradient_checkpointing = gradient_checkpointing
        
        self.tokenizer = BertTokenizer.from_pretrained(config_path)
    
    def get_repr(self, texts: list, batch_size: int = 64, verbose: bool = False) -> torch.Tensor:
        """
        Compute L2-normalized text representations for the given texts.

        Args:
            texts: A list of strings
            batch_size: Batch size for inference
            verbose: Whether to display a progress bar
        """
        device = next(self.parameters()).device

        text_repr = []
        if verbose:
            iterator = tqdm(range(0, len(texts), batch_size), desc="Computing text embeddings")
        else:
            iterator = range(0, len(texts), batch_size)

        for i in iterator:
            # Tokenize the current batch; calling the tokenizer directly is
            # equivalent to the deprecated batch_encode_plus.
            text_inputs = self.tokenizer(texts[i: i + batch_size],
                                         return_tensors="pt",
                                         truncation=True,
                                         max_length=512,
                                         padding=True)
            text_inputs = {k: v.to(device) for k, v in text_inputs.items()}
            output = self(text_inputs)
            text_repr.append(output)

        # Concatenate per-batch outputs and L2-normalize so that dot products
        # between representations equal cosine similarities.
        text_repr = torch.cat(text_repr, dim=0)
        return normalize(text_repr, dim=-1)
    
    def forward(self, inputs: dict):
        """
        Encode text into a text representation.

        Args:
            inputs: A dictionary containing the following keys:
                - input_ids: [batch, seq_len]
                - attention_mask: [batch, seq_len]
                - token_type_ids: [batch, seq_len]

        Returns:
            text_repr: [batch, text_repr_dim]
        """
        # Use the final hidden state of the [CLS] token as the sequence-level
        # representation, then project it to the target output dimension.
        reprs = self.model(**inputs).last_hidden_state[:, 0, :]
        reprs = self.out(reprs)
        return reprs
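

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original file). It
# assumes a standard Hugging Face checkpoint name such as "bert-base-uncased"
# for `config_path` and an arbitrary projection size of 256 for `out_dim`.
if __name__ == "__main__":
    encoder = TextEncoder(config_path="bert-base-uncased", out_dim=256)
    encoder.eval()

    texts = ["a photo of a cat", "a photo of a dog"]
    with torch.no_grad():  # inference only, no gradients needed
        embeddings = encoder.get_repr(texts, batch_size=2, verbose=True)

    print(embeddings.shape)  # torch.Size([2, 256])
    # Representations are L2-normalized, so the matrix product below gives
    # pairwise cosine similarities.
    print(embeddings @ embeddings.T)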