# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional

import torch
from mmengine.model import BaseModule

from mmpretrain.registry import MODELS
class SeqGenerationHead(BaseModule):
    """Generation head for multi-modal pre-trained task, adopted by BLIP.

    Normally used for generation task.

    Args:
        decoder (dict): Decoder for blip generation head.
        ignore_index (int): Label value that marks positions to be excluded
            from the loss (e.g. padding). Defaults to -100.
        loss (dict): Config of the language-modeling loss. Defaults to
            ``dict(type='LabelSmoothLoss', label_smooth_val=0.1)``.
        init_cfg (dict, optional): the config to control the initialization.
            Defaults to None.
    """

    def __init__(
        self,
        decoder: dict,
        ignore_index: int = -100,
        loss: dict = dict(type='LabelSmoothLoss', label_smooth_val=0.1),
        init_cfg: Optional[dict] = None,
    ) -> None:
        super().__init__(init_cfg=init_cfg)
        self.decoder = MODELS.build(decoder)
        self.loss_fn = MODELS.build(loss)
        self.ignore_index = ignore_index

    def forward(self, input_ids: torch.Tensor,
                encoder_hidden_states: torch.Tensor,
                encoder_attention_mask: torch.Tensor, labels: torch.Tensor):
        """Forward to get decoder output.

        Args:
            input_ids (torch.Tensor): The tokenized input text tensor.
            encoder_hidden_states (torch.Tensor): Hidden states from image
                embeddings.
            encoder_attention_mask (torch.Tensor): Image embeddings hidden
                states attention mask.
            labels (torch.Tensor): Decoder target for calculate loss.

        Returns:
            dict[str, Tensor]: a dictionary of decoder outputs.
        """
        decoder_out = self.decoder(
            input_ids=input_ids,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            labels=labels,
            return_dict=True,
        )
        return decoder_out

    def loss(self, input_ids, encoder_hidden_states, encoder_attention_mask,
             labels):
        """Calculate losses from the extracted features.

        Args:
            input_ids (torch.Tensor): The tokenized input text tensor.
            encoder_hidden_states (torch.Tensor): Hidden states from image
                embeddings.
            encoder_attention_mask (torch.Tensor): Image embeddings hidden
                states attention mask.
            labels (torch.Tensor): Decoder target for calculate loss.

        Returns:
            dict[str, Tensor]: a dictionary of loss components.
        """
        decoder_out = self(
            input_ids=input_ids,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            labels=labels,
        )
        prediction_scores = decoder_out['logits']
        # We are doing next-token prediction: shift prediction scores and
        # target ids by one position, then flatten both to token level so
        # their first dimensions always agree.
        shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
        labels = labels[:, 1:].contiguous().view(-1)
        vocab_size = prediction_scores.shape[-1]

        # Mask ignored positions so they contribute nothing to the loss.
        if (labels == self.ignore_index).any():
            labels = labels.clone()
            ignore_mask = (labels == self.ignore_index)
            # Remap ignored targets to a valid class id (0); their weight of
            # zero removes them from the loss anyway.
            labels.masked_fill_(ignore_mask, 0)
            weight = torch.logical_not(ignore_mask)
            # Guard against an all-ignored batch dividing by zero.
            avg_factor = max(weight.sum(), 1)
        else:
            weight = None
            # Fix: average over the number of token positions. The original
            # computed ``labels.size(0)`` on the un-flattened (batch, seq)
            # tensor — i.e. the batch size — and passed 2-D labels alongside
            # flattened scores. With labels flattened above, ``size(0)`` is
            # the token count and the shapes are consistent.
            avg_factor = labels.size(0)

        lm_loss = self.loss_fn(
            shifted_prediction_scores.view(-1, vocab_size),
            labels,
            weight=weight,
            avg_factor=avg_factor,
        )
        losses = {
            'seq_gen_lm_loss': lm_loss,
        }
        return losses

    def predict(self,
                input_ids,
                encoder_hidden_states,
                sep_token_id,
                pad_token_id,
                use_nucleus_sampling=False,
                num_beams=3,
                max_length=20,
                min_length=2,
                top_p=0.9,
                repetition_penalty=1.0,
                **kwargs):
        """Decoder prediction method.

        Args:
            input_ids (torch.Tensor): The tokenized input text tensor.
            encoder_hidden_states (torch.Tensor): Hidden states from image
                embeddings.
            sep_token_id (int): Tokenid of separation token.
            pad_token_id (int): Tokenid of pad token.
            use_nucleus_sampling (bool): Whether to use nucleus sampling in
                prediction. Defaults to False.
            num_beams (int): Number of beams used in predition.
                Defaults to 3.
            max_length (int): Max length of generated text in predition.
                Defaults to 20.
            min_length (int): Min length of generated text in predition.
                Defaults to 2.
            top_p (float):
                If < 1.0, only keep the top tokens with cumulative probability
                >= top_p (nucleus filtering). Defaults to 0.9.
            repetition_penalty (float): The parameter for repetition penalty.
                Only used by the beam-search branch; the nucleus-sampling
                branch hard-codes 1.1 (see NOTE below). Defaults to 1.0.
            **kwarg: Other arguments that might used in generation.

        Returns:
            dict[str, Tensor]: a dictionary of generation outputs.
        """
        device = encoder_hidden_states.device
        # TODO: In old version of transformers
        # Additional repeat interleave of hidden states should be add here.

        # The image embeddings are always fully attended to.
        image_atts = torch.ones(
            encoder_hidden_states.size()[:-1], dtype=torch.long,
            device=device)

        model_kwargs = {
            'encoder_hidden_states': encoder_hidden_states,
            'encoder_attention_mask': image_atts,
        }
        model_kwargs.update(kwargs)

        if use_nucleus_sampling:
            # nucleus sampling
            # NOTE(review): ``repetition_penalty`` is hard-coded to 1.1 here,
            # ignoring the method argument — this matches the BLIP reference
            # implementation, but confirm it is intentional.
            outputs = self.decoder.generate(
                input_ids=input_ids,
                max_length=max_length,
                min_length=min_length,
                do_sample=True,
                top_p=top_p,
                num_return_sequences=1,
                eos_token_id=sep_token_id,
                pad_token_id=pad_token_id,
                repetition_penalty=1.1,
                **model_kwargs)
        else:
            # beam search
            outputs = self.decoder.generate(
                input_ids=input_ids,
                max_length=max_length,
                min_length=min_length,
                num_beams=num_beams,
                eos_token_id=sep_token_id,
                pad_token_id=pad_token_id,
                repetition_penalty=repetition_penalty,
                **model_kwargs)

        return outputs