from transformers import T5EncoderModel, T5Tokenizer  # type: ignore
import torch
from torch import nn


class T5Conditioner(nn.Module):
    """Text conditioner: embeds descriptions with a (frozen or finetuned) T5 encoder
    and projects the hidden states to ``output_dim``."""

    MODELS = ["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b",
              "google/flan-t5-small", "google/flan-t5-base", "google/flan-t5-large",
              "google/flan-t5-xl", "google/flan-t5-xxl"]
    MODELS_DIMS = {
        "t5-small": 512,
        "t5-base": 768,
        "t5-large": 1024,
        "t5-3b": 1024,
        "t5-11b": 1024,
        "google/flan-t5-small": 512,
        "google/flan-t5-base": 768,
        "google/flan-t5-large": 1024,
        "google/flan-t5-xl": 2048,
        "google/flan-t5-xxl": 4096,
    }

    def __init__(self, name, output_dim, device, finetune=False):
        print(f'{finetune=}')
        assert name in self.MODELS, f"Unrecognized t5 model name (should be in {self.MODELS})"
        super().__init__()
        self.dim = self.MODELS_DIMS[name]
        self.output_dim = output_dim
        self.output_proj = nn.Linear(self.dim, output_dim)
        self.device = device
        self.name = name
        self.finetune = finetune
        self.t5_tokenizer = T5Tokenizer.from_pretrained(name)
        t5 = T5EncoderModel.from_pretrained(name).train(mode=finetune)
        if finetune:
            self.t5 = t5
        else:
            # Storing the frozen T5 in __dict__ (not as a registered submodule)
            # makes sure it is not part of the saved checkpoint.
            self.__dict__['t5'] = t5.to(device)

    def tokenize(self, x):
        # Replace missing descriptions with the empty string before tokenizing.
        entries = [xi if xi is not None else "" for xi in x]
        inputs = self.t5_tokenizer(entries, return_tensors='pt', padding=True).to(self.device)
        return inputs  # dict with 'input_ids' and 'attention_mask'

    def forward(self, descriptions):
        d = self.tokenize(descriptions)
        # Gradients flow through T5 only when finetuning; no kv-cache is needed
        # for text conditioning since this is a single encoder pass.
        with torch.set_grad_enabled(self.finetune):
            embeds = self.t5(input_ids=d['input_ids'],
                             attention_mask=d['attention_mask']).last_hidden_state
        embeds = self.output_proj(embeds.to(self.output_proj.weight))
        # Zero out embeddings at padded positions.
        embeds = embeds * d['attention_mask'].unsqueeze(-1)
        return embeds  # , d['attention_mask']
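
# --- Usage sketch (illustrative) ---
# A minimal example of wiring up the conditioner. The output_dim of 1536, the
# example descriptions, and the device selection below are assumptions made for
# illustration, not values taken from the original module.
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    conditioner = T5Conditioner(name="t5-base", output_dim=1536, device=device).to(device)
    descriptions = ["a calm piano melody", None, "an upbeat electronic track"]
    cond = conditioner(descriptions)
    # Shape: (batch, seq_len, output_dim); padded positions are zeroed out.
    print(cond.shape)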