import warnings
from transformers import T5EncoderModel, T5Tokenizer # type: ignore
import torch
from torch import nn
class T5Conditioner(nn.Module):
    """Text conditioner: embeds string descriptions with a (frozen or
    finetunable) T5 encoder and projects them to ``output_dim``.

    Padding positions of the returned embeddings are zeroed via the
    attention mask.
    """

    # Hugging Face checkpoint names accepted for `name`.
    MODELS = ["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b",
              "google/flan-t5-small", "google/flan-t5-base", "google/flan-t5-large",
              "google/flan-t5-xl", "google/flan-t5-xxl"]
    # Encoder hidden size (d_model) per checkpoint.
    # Bug fix: the flan keys previously read "google/flan-t5-3b"/"-11b" —
    # names that are not in MODELS — so the valid names
    # "google/flan-t5-xl"/"-xxl" passed validation but raised KeyError in
    # __init__. Keys renamed and dims set to the checkpoints' actual
    # d_model (xl: 2048, xxl: 4096).
    MODELS_DIMS = {
        "t5-small": 512,
        "t5-base": 768,
        "t5-large": 1024,
        "t5-3b": 1024,
        "t5-11b": 1024,
        "google/flan-t5-small": 512,
        "google/flan-t5-base": 768,
        "google/flan-t5-large": 1024,
        "google/flan-t5-xl": 2048,
        "google/flan-t5-xxl": 4096,
    }

    def __init__(self,
                 name,
                 output_dim,
                 device,
                 finetune=False):
        """
        Args:
            name: one of ``MODELS`` — the T5 checkpoint to load.
            output_dim: dimension the T5 hidden states are projected to.
            device: device the (frozen) T5 encoder and tokenized inputs live on.
            finetune: if True, register the T5 encoder as a submodule so its
                weights are part of the checkpoint and receive gradients;
                if False, keep it out of the state_dict and frozen.

        Raises:
            ValueError: if ``name`` is not a recognized checkpoint.
        """
        # Raise instead of assert: asserts are stripped under `python -O`.
        if name not in self.MODELS:
            raise ValueError(f"Unrecognized t5 model name (should in {self.MODELS})")
        super().__init__()
        self.dim = self.MODELS_DIMS[name]
        self.output_dim = output_dim
        self.output_proj = nn.Linear(self.dim, output_dim)
        self.device = device
        self.name = name
        # Remembered so forward() can enable gradients only when finetuning.
        self.finetune = finetune
        self.t5_tokenizer = T5Tokenizer.from_pretrained(name)
        t5 = T5EncoderModel.from_pretrained(name).eval()  # .train(mode=finetune)
        if finetune:
            self.t5 = t5
        else:
            # Assigning through __dict__ bypasses nn.Module registration:
            # this makes sure that the t5 model is not part
            # of the saved checkpoint.
            self.__dict__['t5'] = t5.to(device)

    def tokenize(self, x):
        """Tokenize a batch of (possibly None) strings; None becomes "".

        Returns the tokenizer output moved to ``self.device``
        (contains 'input_ids' and 'attention_mask').
        """
        entries = [xi if xi is not None else "" for xi in x]
        inputs = self.t5_tokenizer(entries,
                                   return_tensors='pt',
                                   padding=True).to(self.device)
        return inputs  # 'input_ids', 'attention_mask'

    def forward(self, descriptions):
        """Embed ``descriptions`` into a (batch, seq_len, output_dim) tensor
        with padded positions zeroed out.
        """
        d = self.tokenize(descriptions)
        # Bug fix: this was an unconditional `torch.no_grad()`, which
        # silently blocked gradients through T5 even when finetune=True.
        # Gradients are now enabled exactly when finetuning.
        with torch.set_grad_enabled(self.finetune):
            embeds = self.t5(input_ids=d['input_ids'],
                             attention_mask=d['attention_mask']
                             ).last_hidden_state  # no kv-cache for txt conditioning
        embeds = self.output_proj(embeds.to(self.output_proj.weight))
        # Zero the embeddings of padding tokens.
        embeds = (embeds * d['attention_mask'].unsqueeze(-1))
        return embeds  # , d['attention_mask']