from transformers import T5EncoderModel, T5Tokenizer  # type: ignore
import torch
from torch import nn


class T5Conditioner(nn.Module):
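    """Text conditioner built on a pretrained T5 encoder.

    Encodes a batch of text descriptions with a (by default frozen) T5 encoder
    and projects the hidden states to ``output_dim``. Padded positions are
    zeroed in the returned embeddings.
    """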

    MODELS = ["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b",
              "google/flan-t5-small", "google/flan-t5-base", "google/flan-t5-large",
              "google/flan-t5-xl", "google/flan-t5-xxl"]
    MODELS_DIMS = {
        "t5-small": 512,
        "t5-base": 768,
        "t5-large": 1024,
        "t5-3b": 1024,
        "t5-11b": 1024,
        "google/flan-t5-small": 512,
        "google/flan-t5-base": 768,
        "google/flan-t5-large": 1024,
        "google/flan-t5-3b": 1024,
        "google/flan-t5-11b": 1024,
    }

    def __init__(self,
                 name,
                 output_dim,
                 device,
                 finetune=False):
        assert name in self.MODELS, f"Unrecognized T5 model name (should be in {self.MODELS})"
        super().__init__()
        self.dim = self.MODELS_DIMS[name]
        self.output_dim = output_dim
        self.output_proj = nn.Linear(self.dim, output_dim)
        self.device = device
        self.name = name
        self.finetune = finetune

        self.t5_tokenizer = T5Tokenizer.from_pretrained(name)
        t5 = T5EncoderModel.from_pretrained(name).train(mode=finetune)
        if finetune:
            self.t5 = t5
        else:
            # Storing the model in __dict__ (instead of registering it as a
            # submodule) keeps the frozen T5 weights out of the saved checkpoint.
            self.__dict__['t5'] = t5.to(device)

    def tokenize(self, x):
        # Replace missing descriptions with empty strings before tokenizing.
        entries = [xi if xi is not None else "" for xi in x]
        inputs = self.t5_tokenizer(entries,
                                   return_tensors='pt',
                                   padding=True).to(self.device)
        return inputs  # contains 'input_ids' and 'attention_mask'

    def forward(self, descriptions):
        d = self.tokenize(descriptions)
        # Only track gradients through T5 when it is being finetuned.
        with torch.set_grad_enabled(self.finetune):
            embeds = self.t5(input_ids=d['input_ids'],
                             attention_mask=d['attention_mask']
                             ).last_hidden_state  # no kv-cache for text conditioning
        embeds = self.output_proj(embeds.to(self.output_proj.weight))
        # Zero out the embeddings at padded positions.
        embeds = embeds * d['attention_mask'].unsqueeze(-1)
        return embeds
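

# Minimal usage sketch (illustrative addition, not part of the original module).
# It assumes internet access to download the "t5-small" weights and picks CUDA
# when available; batch entries may be None and are treated as empty strings.
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    conditioner = T5Conditioner(name="t5-small", output_dim=512, device=device).to(device)
    descriptions = ["a calm piano melody", None, "an upbeat electronic track"]
    embeds = conditioner(descriptions)
    # One vector per token, projected to output_dim; padded positions are zeroed.
    print(embeds.shape)  # e.g. torch.Size([3, seq_len, 512])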