import torch
import numpy as np
import openvino as ov

from typing import List, Dict
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions


def init_past_inputs(model_inputs: List):
    """
    Helper function for initializing past key/value inputs on the first inference step.

    Parameters:
      model_inputs (List): list of model inputs
    Returns:
      pkv (List[ov.Tensor]): list of empty past key value tensors
    """
    pkv = []
    # The first 4 inputs are input_ids, attention_mask, encoder_hidden_states and
    # encoder_attention_mask; everything after them is a past key/value input.
    for input_tensor in model_inputs[4:]:
        partial_shape = input_tensor.partial_shape
        partial_shape[0] = 1  # batch size
        partial_shape[2] = 0  # sequence length: the cache is empty on the first step
        pkv.append(ov.Tensor(ov.Type.f32, partial_shape.get_shape()))
    return pkv


def postprocess_text_decoder_outputs(output: Dict):
    """
    Helper function for rearranging model outputs and wrapping them in CausalLMOutputWithCrossAttentions.

    Parameters:
      output (Dict): dictionary with model outputs
    Returns:
      wrapped_outputs (CausalLMOutputWithCrossAttentions): outputs wrapped in CausalLMOutputWithCrossAttentions format
    """
    logits = torch.from_numpy(output[0])
    # Every output after the logits is an updated past key value tensor for caching.
    past_kv = list(output.values())[1:]
    return CausalLMOutputWithCrossAttentions(
        loss=None,
        logits=logits,
        past_key_values=past_kv,
        hidden_states=None,
        attentions=None,
        cross_attentions=None,
    )


def text_decoder_forward(
    ov_text_decoder_with_past: ov.CompiledModel,
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor,
    past_key_values: List[ov.Tensor],
    encoder_hidden_states: torch.Tensor,
    encoder_attention_mask: torch.Tensor,
    **kwargs
):
    """
    Inference function for the text decoder in one generation step.

    Parameters:
      ov_text_decoder_with_past (ov.CompiledModel): compiled OpenVINO text decoder with past key value inputs
      input_ids (torch.Tensor): input token ids
      attention_mask (torch.Tensor): attention mask for input token ids
      past_key_values (List[ov.Tensor]): list of cached decoder hidden states from the previous step
      encoder_hidden_states (torch.Tensor): encoder (vision or text) hidden states
      encoder_attention_mask (torch.Tensor): attention mask for encoder hidden states
    Returns:
      model outputs (CausalLMOutputWithCrossAttentions): model prediction wrapped in the CausalLMOutputWithCrossAttentions class, including predicted logits and hidden states for caching
    """
    inputs = [input_ids, attention_mask, encoder_hidden_states, encoder_attention_mask]
    if past_key_values is None:
        # First generation step: fill the cache inputs with empty tensors.
        inputs.extend(init_past_inputs(ov_text_decoder_with_past.inputs))
    else:
        inputs.extend(past_key_values)
    outputs = ov_text_decoder_with_past(inputs)
    return postprocess_text_decoder_outputs(outputs)
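

# A minimal usage sketch: `generate()` on the original transformers text decoder can be
# routed through the compiled OpenVINO model by patching its forward method. The names
# `text_decoder` (the transformers BLIP text decoder) and `ov_text_decoder_with_past`
# (its compiled OpenVINO counterpart) are assumptions, not defined in this file:
#
#   from functools import partial
#   text_decoder.forward = partial(text_decoder_forward, ov_text_decoder_with_past)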


class OVBlipModel:
    """
    Model class for inference of the BLIP model with OpenVINO.
    """

    def __init__(
        self,
        config,
        decoder_start_token_id: int,
        vision_model,
        text_encoder,
        text_decoder,
    ):
        """
        Initialization of class parameters.
        """
        self.vision_model = vision_model
        self.vision_model_out = vision_model.output(0)
        self.text_encoder = text_encoder
        self.text_encoder_out = text_encoder.output(0)
        self.text_decoder = text_decoder
        self.config = config
        self.decoder_start_token_id = decoder_start_token_id
        self.decoder_input_ids = config.text_config.bos_token_id

    def generate_answer(self, pixel_values: torch.Tensor, input_ids: torch.Tensor, attention_mask: torch.Tensor, **generate_kwargs):
        """
        Visual Question Answering prediction.

        Parameters:
          pixel_values (torch.Tensor): preprocessed image pixel values
          input_ids (torch.Tensor): question token ids after tokenization
          attention_mask (torch.Tensor): attention mask for question tokens
        Returns:
          generation output (torch.Tensor): tensor representing the sequence of generated answer token ids
        """
        image_embed = self.vision_model(pixel_values.detach().numpy())[self.vision_model_out]
        image_attention_mask = np.ones(image_embed.shape[:-1], dtype=int)
        if isinstance(input_ids, list):
            input_ids = torch.LongTensor(input_ids)
        # Fuse the question tokens with the image embeddings in the text encoder.
        question_embeds = self.text_encoder(
            [
                input_ids.detach().numpy(),
                attention_mask.detach().numpy(),
                image_embed,
                image_attention_mask,
            ]
        )[self.text_encoder_out]
        question_attention_mask = np.ones(question_embeds.shape[:-1], dtype=int)
        # Start decoding from the BOS token, conditioned on the question embeddings.
        bos_ids = np.full((question_embeds.shape[0], 1), fill_value=self.decoder_start_token_id)
        outputs = self.text_decoder.generate(
            input_ids=torch.from_numpy(bos_ids),
            eos_token_id=self.config.text_config.sep_token_id,
            pad_token_id=self.config.text_config.pad_token_id,
            encoder_hidden_states=torch.from_numpy(question_embeds),
            encoder_attention_mask=torch.from_numpy(question_attention_mask),
            **generate_kwargs,
        )
        return outputs
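
    # A minimal VQA usage sketch (assumptions: `processor` is a transformers
    # BlipProcessor, `raw_image` is a PIL image, and `ov_model` is an initialized
    # OVBlipModel; none are defined in this file):
    #
    #   inputs = processor(raw_image, question, return_tensors="pt")
    #   out = ov_model.generate_answer(**inputs, max_length=10)
    #   answer = processor.decode(out[0], skip_special_tokens=True)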

    def generate_caption(self, pixel_values: torch.Tensor, input_ids: torch.Tensor = None, attention_mask: torch.Tensor = None, **generate_kwargs):
        """
        Image Captioning prediction.

        Parameters:
          pixel_values (torch.Tensor): preprocessed image pixel values
          input_ids (torch.Tensor, *optional*, None): pre-generated caption token ids after tokenization; if provided, caption generation continues the provided text
          attention_mask (torch.Tensor): attention mask for caption tokens, used only if input_ids is provided
        Returns:
          generation output (torch.Tensor): tensor representing the sequence of generated caption token ids
        """
        batch_size = pixel_values.shape[0]
        image_embeds = self.vision_model(pixel_values.detach().numpy())[self.vision_model_out]
        image_attention_mask = torch.ones(image_embeds.shape[:-1], dtype=torch.long)
        if isinstance(input_ids, list):
            input_ids = torch.LongTensor(input_ids)
        elif input_ids is None:
            # No prompt given: start from a [BOS, EOS] pair for each image in the batch.
            input_ids = torch.LongTensor(
                [
                    [
                        self.config.text_config.bos_token_id,
                        self.config.text_config.eos_token_id,
                    ]
                ]
            ).repeat(batch_size, 1)
        input_ids[:, 0] = self.config.text_config.bos_token_id
        attention_mask = attention_mask[:, :-1] if attention_mask is not None else None
        outputs = self.text_decoder.generate(
            input_ids=input_ids[:, :-1],  # drop the last token; generation continues from there
            eos_token_id=self.config.text_config.sep_token_id,
            pad_token_id=self.config.text_config.pad_token_id,
            attention_mask=attention_mask,
            encoder_hidden_states=torch.from_numpy(image_embeds),
            encoder_attention_mask=image_attention_mask,
            **generate_kwargs,
        )
        return outputs
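

# A minimal end-to-end captioning sketch (assumptions: `model` is the original
# transformers BLIP model the OpenVINO IRs were exported from, `text_decoder` is its
# decoder with forward patched as shown above, `processor` and `raw_image` are as in
# the VQA sketch, and the .xml paths are placeholders for the exported IR files):
#
#   core = ov.Core()
#   ov_model = OVBlipModel(
#       model.config,
#       model.decoder_start_token_id,
#       core.compile_model("blip_vision_model.xml", "CPU"),
#       core.compile_model("blip_text_encoder.xml", "CPU"),
#       text_decoder,
#   )
#   inputs = processor(raw_image, return_tensors="pt")
#   caption_ids = ov_model.generate_caption(inputs["pixel_values"], max_length=20)
#   caption = processor.decode(caption_ids[0], skip_special_tokens=True)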