"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
Based on facebookresearch code base
https://github.com/facebookresearch/FiD
"""
import torch
import torch.nn as nn

from lavis.common.registry import registry
from lavis.common.utils import get_abs_path
from lavis.models.base_model import BaseModel
from transformers import T5Config, T5Tokenizer, T5ForConditionalGeneration


@registry.register_model("pnp_unifiedqav2_fid")
class PNPUnifiedQAv2FiD(T5ForConditionalGeneration, BaseModel):
    """UnifiedQAv2 reader with Fusion-in-Decoder: contexts are encoded
    separately and fused in the decoder's cross-attention."""

    PRETRAINED_MODEL_CONFIG_DICT = {}

    def __init__(self, config, model_path):
        super().__init__(config)
        self.tokenizer = T5Tokenizer.from_pretrained(model_path)

    def forward(self, input_ids=None, attention_mask=None, **kwargs):
        if input_ids is not None:
            # Inputs may arrive as (batch, n_contexts, seq_len). Record the
            # number of contexts so the wrapped encoder can split them back
            # out, then flatten to (batch, n_contexts * seq_len) for T5.
            if input_ids.dim() == 3:
                self.encoder.num_contexts = input_ids.size(1)
            input_ids = input_ids.view(input_ids.size(0), -1)
        if attention_mask is not None:
            attention_mask = attention_mask.view(attention_mask.size(0), -1)
        return super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs,
        )

    def generate(self, input_ids, attention_mask, num_beams=1, min_length=0, max_length=20):
        # input_ids / attention_mask: (batch, n_contexts, seq_len).
        self.encoder.num_contexts = input_ids.size(1)
        return super().generate(
            input_ids=input_ids.view(input_ids.size(0), -1),
            attention_mask=attention_mask.view(attention_mask.size(0), -1),
            num_beams=num_beams,
            min_length=min_length,
            max_length=max_length,
        )

    def load_unifiedqa(self, state_dict):
        self.load_state_dict(state_dict)
        # Wrap the encoder only after loading the weights, so the
        # checkpoint's parameter names still match the vanilla T5 layout.
        self.encoder = T5EncoderWrapper(self.encoder)

    @classmethod
    def from_config(cls, cfg):
        model_path = cfg.get("pretrained")
        t5_config_path = get_abs_path(cfg.get("t5_config_path"))
        t5_config = T5Config.from_json_file(t5_config_path)

        model = cls(t5_config, model_path)
        model.load_unifiedqa(
            T5ForConditionalGeneration.from_pretrained(model_path).state_dict()
        )

        return model
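
# Example usage (hypothetical values), assuming a LAVIS-style config whose
# "pretrained" key names a UnifiedQAv2 checkpoint on the HF Hub and whose
# "t5_config_path" points at a matching T5 config JSON inside LAVIS:
#
#   cfg = {
#       "pretrained": "allenai/unifiedqa-v2-t5-base-1363200",
#       "t5_config_path": "configs/models/pnp-vqa/unifiedqa_v2_base.json",
#   }
#   model = PNPUnifiedQAv2FiD.from_config(cfg)
#   # input_ids / attention_mask: (batch, n_contexts, seq_len) tensors.
#   answer_ids = model.generate(input_ids, attention_mask, num_beams=5)
#   answers = model.tokenizer.batch_decode(answer_ids, skip_special_tokens=True)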

class T5EncoderWrapper(nn.Module):
    """Encoder wrapper that encodes each context independently and
    concatenates the resulting hidden states along the sequence axis."""

    def __init__(self, encoder):
        super().__init__()
        self.encoder = encoder
        # Expose the attributes T5ForConditionalGeneration expects to find
        # on its encoder.
        self.block = self.encoder.block
        self.parallelize = self.encoder.parallelize
        self.main_input_name = encoder.main_input_name

    def forward(self, input_ids=None, attention_mask=None, **kwargs):
        # (batch, n_contexts * seq_len) -> (batch * n_contexts, seq_len):
        # each context is encoded independently.
        bsz, total_length = input_ids.shape
        context_length = total_length // self.num_contexts
        input_ids = input_ids.view(bsz * self.num_contexts, context_length)
        attention_mask = attention_mask.view(bsz * self.num_contexts, context_length)

        outputs = self.encoder(input_ids, attention_mask, **kwargs)
        # Concatenate the per-context hidden states back along the sequence
        # axis so the decoder cross-attends over all contexts at once.
        outputs = (
            outputs[0].view(bsz, self.num_contexts * context_length, -1),
        ) + outputs[1:]
        return outputs
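

# Minimal shape sketch (illustrative sizes only) of the FiD flattening
# contract used above; it exercises only the tensor reshapes, so it runs
# without any model weights.
if __name__ == "__main__":
    bsz, n_contexts, seq_len, d_model = 2, 4, 8, 16

    # What PNPUnifiedQAv2FiD.forward does with a 3-D batch:
    input_ids = torch.zeros(bsz, n_contexts, seq_len, dtype=torch.long)
    flat = input_ids.view(bsz, -1)                      # (2, 32)

    # What T5EncoderWrapper.forward does with the flattened batch:
    per_ctx = flat.view(bsz * n_contexts, seq_len)      # (8, 8)
    hidden = torch.randn(bsz * n_contexts, seq_len, d_model)
    fused = hidden.view(bsz, n_contexts * seq_len, -1)  # (2, 32, 16)

    print(flat.shape, per_ctx.shape, fused.shape)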