import warnings

import torch
from torch import nn
from transformers import T5EncoderModel, T5Tokenizer


class T5Conditioner(nn.Module):
    """Text conditioner built on a pretrained T5 encoder.

    Descriptions are tokenized and encoded with T5, then projected to
    `output_dim` by a trainable linear layer. Padded positions are zeroed
    in the output. When `finetune` is False the encoder is kept frozen and
    excluded from the module's registered parameters.
    """

    MODELS = ["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b",
              "google/flan-t5-small", "google/flan-t5-base", "google/flan-t5-large",
              "google/flan-t5-xl", "google/flan-t5-xxl"]
    # Hidden size (d_model) of each supported checkpoint.
    MODELS_DIMS = {
        "t5-small": 512,
        "t5-base": 768,
        "t5-large": 1024,
        "t5-3b": 1024,
        "t5-11b": 1024,
        "google/flan-t5-small": 512,
        "google/flan-t5-base": 768,
        "google/flan-t5-large": 1024,
        "google/flan-t5-xl": 2048,
        "google/flan-t5-xxl": 4096,
    }
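    # The hidden sizes above could also be read from each checkpoint's config
    # instead of being hardcoded. A minimal sketch (assumes the config can be
    # fetched from the Hugging Face Hub):
    #
    #     from transformers import AutoConfig
    #     dim = AutoConfig.from_pretrained("google/flan-t5-xl").d_model  # -> 2048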

    def __init__(self,
                 name,
                 output_dim,
                 device,
                 finetune=False):
        assert name in self.MODELS, f"Unrecognized T5 model name (should be in {self.MODELS})"
        super().__init__()
        self.dim = self.MODELS_DIMS[name]
        self.output_dim = output_dim
        self.output_proj = nn.Linear(self.dim, output_dim)
        self.device = device
        self.name = name
        self.finetune = finetune

        self.t5_tokenizer = T5Tokenizer.from_pretrained(name)
        t5 = T5EncoderModel.from_pretrained(name).eval()
        if finetune:
            # Register the encoder as a submodule so its weights are trained
            # and saved with the rest of the conditioner.
            self.t5 = t5
        else:
            # Bypass nn.Module registration so the frozen encoder is neither
            # returned by parameters() nor stored in checkpoints.
            self.__dict__['t5'] = t5.to(device)

    def tokenize(self, x):
        # Replace missing descriptions with the empty string so every entry
        # still gets a row in the padded batch.
        entries = [xi if xi is not None else "" for xi in x]
        inputs = self.t5_tokenizer(entries,
                                   return_tensors='pt',
                                   padding=True).to(self.device)
        return inputs

    def forward(self, descriptions):
        d = self.tokenize(descriptions)
        # Only track gradients through the encoder when it is being finetuned.
        with torch.set_grad_enabled(self.finetune):
            embeds = self.t5(input_ids=d['input_ids'],
                             attention_mask=d['attention_mask']
                             ).last_hidden_state
        # Match the projection layer's dtype/device, then project to output_dim.
        embeds = self.output_proj(embeds.to(self.output_proj.weight))
        # Zero out the embeddings of padded positions.
        embeds = embeds * d['attention_mask'].unsqueeze(-1)
        return embeds
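

# Minimal usage sketch: assumes the "t5-small" checkpoint (and its sentencepiece
# tokenizer) can be downloaded, and uses CPU plus an arbitrary output_dim purely
# for illustration.
if __name__ == "__main__":
    conditioner = T5Conditioner(name="t5-small", output_dim=1536, device="cpu")
    # None entries are treated as empty descriptions by tokenize().
    embeds = conditioner(["a calm piano melody", None, "fast drum and bass"])
    print(embeds.shape)  # (batch, padded_seq_len, output_dim)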