Spaces:
Sleeping
Sleeping
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause

Based on facebookresearch code base
https://github.com/facebookresearch/FiD
"""
import torch | |
import torch.nn as nn | |
from lavis.common.registry import registry | |
from lavis.models.base_model import BaseModel | |
from lavis.common.utils import get_abs_path | |
from transformers import T5Config, T5Tokenizer, T5ForConditionalGeneration | |
class PNPUnifiedQAv2FiD(T5ForConditionalGeneration, BaseModel):
    """T5 conditional-generation model that accepts several contexts per example.

    Inputs may be given as 3-D tensors; ``forward`` and ``generate`` record
    the size of dimension 1 on the encoder as ``num_contexts`` and flatten the
    tensors to 2-D before delegating to ``T5ForConditionalGeneration``. After
    ``load_unifiedqa`` the encoder is a ``T5EncoderWrapper``, which uses
    ``num_contexts`` to split the flattened input back into separate contexts.
    """

    PRETRAINED_MODEL_CONFIG_DICT = {}

    def __init__(self, config, model_path):
        """Build the T5 model from ``config`` and load the tokenizer.

        Args:
            config: T5 configuration passed to the parent constructor.
            model_path: Name or path handed to ``T5Tokenizer.from_pretrained``.
        """
        super().__init__(config)
        self.tokenizer = T5Tokenizer.from_pretrained(model_path)

    def forward(self, input_ids=None, attention_mask=None, **kwargs):
        """Flatten multi-context inputs and run the parent ``forward``.

        Args:
            input_ids: Optional tensor. When 3-D, dimension 1 is stored on the
                encoder as ``num_contexts``. It is then reshaped to
                ``(input_ids.size(0), -1)``.
            attention_mask: Optional tensor, reshaped to
                ``(attention_mask.size(0), -1)``.
            **kwargs: Forwarded unchanged to the parent ``forward``.

        Returns:
            Whatever ``T5ForConditionalGeneration.forward`` returns.
        """
        if input_ids is not None:
            # Inputs may already be 2-D (``generate`` flattens them before
            # delegating), so only a 3-D tensor updates the context count.
            if input_ids.dim() == 3:
                self.encoder.num_contexts = input_ids.size(1)
            input_ids = input_ids.view(input_ids.size(0), -1)
        if attention_mask is not None:
            attention_mask = attention_mask.view(attention_mask.size(0), -1)
        return super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs
        )

    def generate(self, input_ids, attention_mask, num_beams=1, min_length=0, max_length=20):
        """Generate output sequences from multi-context inputs.

        Dimension 1 of ``input_ids`` is stored on the encoder as
        ``num_contexts``; both tensors are then flattened to 2-D and passed to
        the parent ``generate`` together with the decoding options.

        Args:
            input_ids: Input token ids tensor.
            attention_mask: Attention mask tensor.
            num_beams: Forwarded to the parent ``generate``.
            min_length: Forwarded to the parent ``generate``.
            max_length: Forwarded to the parent ``generate``.

        Returns:
            Whatever ``T5ForConditionalGeneration.generate`` returns.
        """
        self.encoder.num_contexts = input_ids.size(1)
        return super().generate(
            input_ids=input_ids.view(input_ids.size(0), -1),
            attention_mask=attention_mask.view(attention_mask.size(0), -1),
            num_beams=num_beams,
            min_length=min_length,
            max_length=max_length
        )

    def load_unifiedqa(self, state_dict):
        """Load ``state_dict`` into the model, then wrap the encoder.

        The weights are loaded first and the encoder is replaced by a
        ``T5EncoderWrapper`` afterwards.
        NOTE(review): presumably the order matters because wrapping changes the
        encoder's parameter key prefix -- confirm before reordering.
        """
        self.load_state_dict(state_dict)
        self.encoder = T5EncoderWrapper(self.encoder)

    @classmethod
    def from_config(cls, cfg):
        """Create a model from a configuration object.

        Reads ``pretrained`` (model name/path) and ``t5_config_path`` (JSON T5
        config, resolved with ``get_abs_path``) from ``cfg``, builds the model,
        and loads the weights of the pretrained
        ``T5ForConditionalGeneration`` into it via ``load_unifiedqa``.

        Args:
            cfg: Configuration object exposing ``get``.

        Returns:
            The constructed model with pretrained weights and wrapped encoder.
        """
        model_path = cfg.get('pretrained')
        t5_config_path = get_abs_path(cfg.get("t5_config_path"))
        t5_config = T5Config.from_json_file(t5_config_path)
        model = cls(t5_config, model_path)
        model.load_unifiedqa(T5ForConditionalGeneration.from_pretrained(model_path).state_dict())
        return model
class T5EncoderWrapper(torch.nn.Module):
    """Encoder wrapper that encodes each context of an example separately.

    The wrapped encoder receives inputs reshaped so that every context is its
    own row; the resulting hidden states are then reshaped so that all
    contexts of one example form a single sequence again.
    """

    def __init__(self, encoder):
        """Wrap ``encoder`` and mirror the attributes read from it.

        Args:
            encoder: Callable encoder exposing ``block``, ``parallelize`` and
                ``main_input_name``.
        """
        super().__init__()
        self.encoder = encoder
        self.block = self.encoder.block
        self.parallelize = self.encoder.parallelize
        self.main_input_name = encoder.main_input_name
        # Number of contexts packed into each flattened input row. Owners
        # overwrite this before calling; 1 makes the wrapper a pass-through
        # instead of failing on a missing attribute.
        self.num_contexts = 1

    def forward(self, input_ids=None, attention_mask=None, **kwargs):
        """Encode every context independently and re-join the hidden states.

        Args:
            input_ids: 2-D tensor ``(bsz, total_length)`` where
                ``total_length`` holds ``num_contexts`` contexts back to back.
            attention_mask: Optional tensor with the same layout as
                ``input_ids``; passed through as ``None`` when absent.
            **kwargs: Forwarded unchanged to the wrapped encoder.

        Returns:
            A tuple whose first element is the encoder's first output reshaped
            to ``(bsz, num_contexts * context_length, -1)``, followed by the
            encoder's remaining outputs unchanged.
        """
        bsz, total_length = input_ids.shape
        context_length = total_length // self.num_contexts
        # One row per context so the encoder processes them independently.
        input_ids = input_ids.view(bsz * self.num_contexts, context_length)
        if attention_mask is not None:
            attention_mask = attention_mask.view(bsz * self.num_contexts, context_length)
        outputs = self.encoder(input_ids, attention_mask, **kwargs)
        # Concatenate the per-context hidden states along the sequence axis.
        fused = outputs[0].view(bsz, self.num_contexts * context_length, -1)
        return (fused,) + outputs[1:]