import torch
import numpy as np
import openvino as ov
from typing import List, Dict
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions


def init_past_inputs(model_inputs: List):
    """
    Helper function for initialization of past inputs on first inference step
    Parameters:
      model_inputs (List): list of model inputs
    Returns:
      pkv (List[ov.Tensor]): list of filled past key values
    """
    pkv = []
    for input_tensor in model_inputs[4:]:
        partial_shape = input_tensor.partial_shape
        partial_shape[0] = 1
        partial_shape[2] = 0
        pkv.append(ov.Tensor(ov.Type.f32, partial_shape.get_shape()))
    return pkv


def postprocess_text_decoder_outputs(output: Dict):
    """
    Helper function for rearranging model outputs and wrapping them to CausalLMOutputWithCrossAttentions
    Parameters:
      output (Dict): dictionary with model output
    Returns:
      wrapped_outputs (CausalLMOutputWithCrossAttentions): outputs wrapped to CausalLMOutputWithCrossAttentions format
    """
    logits = torch.from_numpy(output[0])
    past_kv = list(output.values())[1:]
    return CausalLMOutputWithCrossAttentions(
        loss=None,
        logits=logits,
        past_key_values=past_kv,
        hidden_states=None,
        attentions=None,
        cross_attentions=None,
    )


def text_decoder_forward(
    ov_text_decoder_with_past: ov.CompiledModel,
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor,
    past_key_values: List[ov.Tensor],
    encoder_hidden_states: torch.Tensor,
    encoder_attention_mask: torch.Tensor,
    **kwargs
):
    """
    Inference function for text_decoder in one generation step
    Parameters:
      input_ids (torch.Tensor): input token ids
      attention_mask (torch.Tensor): attention mask for input token ids
      past_key_values (List[ov.Tensor]): list of cached decoder hidden states from the previous step
      encoder_hidden_states (torch.Tensor): encoder (vision or text) hidden states
      encoder_attention_mask (torch.Tensor): attention mask for encoder hidden states
    Returns:
      model outputs (CausalLMOutputWithCrossAttentions): model prediction wrapped to CausalLMOutputWithCrossAttentions class including predicted logits and hidden states for caching
    """
    inputs = [input_ids, attention_mask, encoder_hidden_states, encoder_attention_mask]
    # On the first step there is no cache yet, so pass empty past key value tensors
    if past_key_values is None:
        inputs.extend(init_past_inputs(ov_text_decoder_with_past.inputs))
    else:
        inputs.extend(past_key_values)
    outputs = ov_text_decoder_with_past(inputs)
    return postprocess_text_decoder_outputs(outputs)
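

# How this forward is expected to be wired in (a sketch, not part of this module):
# the Hugging Face text decoder's forward can be replaced with text_decoder_forward
# bound to a compiled OpenVINO model. The variable names and the model file path
# below are assumptions, normally defined in the surrounding notebook/script.
#
#   from functools import partial
#
#   core = ov.Core()
#   ov_text_decoder_with_past = core.compile_model("text_decoder_with_past.xml", "CPU")  # hypothetical path
#   text_decoder.forward = partial(text_decoder_forward, ov_text_decoder_with_past)
#
# After patching, text_decoder.generate() runs each decoding step through OpenVINO
# while reusing the Hugging Face generation loop.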


class OVBlipModel:
    """
    Model class for BLIP model inference with OpenVINO
    """

    def __init__(
        self,
        config,
        decoder_start_token_id: int,
        vision_model,
        text_encoder,
        text_decoder,
    ):
        """
        Initialization of class parameters
        """
        self.vision_model = vision_model
        self.vision_model_out = vision_model.output(0)
        self.text_encoder = text_encoder
        self.text_encoder_out = text_encoder.output(0)
        self.text_decoder = text_decoder
        self.config = config
        self.decoder_start_token_id = decoder_start_token_id
        self.decoder_input_ids = config.text_config.bos_token_id

    def generate_answer(self, pixel_values: torch.Tensor, input_ids: torch.Tensor, attention_mask: torch.Tensor, **generate_kwargs):
        """
        Visual Question Answering prediction
        Parameters:
          pixel_values (torch.Tensor): preprocessed image pixel values
          input_ids (torch.Tensor): question token ids after tokenization
          attention_mask (torch.Tensor): attention mask for question tokens
        Returns:
          generation output (torch.Tensor): tensor which represents sequence of generated answer token ids
        """
        image_embed = self.vision_model(pixel_values.detach().numpy())[self.vision_model_out]
        image_attention_mask = np.ones(image_embed.shape[:-1], dtype=int)
        if isinstance(input_ids, list):
            input_ids = torch.LongTensor(input_ids)
        # Fuse question tokens with image embeddings using the text encoder
        question_embeds = self.text_encoder(
            [
                input_ids.detach().numpy(),
                attention_mask.detach().numpy(),
                image_embed,
                image_attention_mask,
            ]
        )[self.text_encoder_out]
        question_attention_mask = np.ones(question_embeds.shape[:-1], dtype=int)
        # Start answer generation from the decoder start (BOS) token
        bos_ids = np.full((question_embeds.shape[0], 1), fill_value=self.decoder_start_token_id)
        outputs = self.text_decoder.generate(
            input_ids=torch.from_numpy(bos_ids),
            eos_token_id=self.config.text_config.sep_token_id,
            pad_token_id=self.config.text_config.pad_token_id,
            encoder_hidden_states=torch.from_numpy(question_embeds),
            encoder_attention_mask=torch.from_numpy(question_attention_mask),
            **generate_kwargs,
        )
        return outputs

    def generate_caption(self, pixel_values: torch.Tensor, input_ids: torch.Tensor = None, attention_mask: torch.Tensor = None, **generate_kwargs):
        """
        Image Captioning prediction
        Parameters:
          pixel_values (torch.Tensor): preprocessed image pixel values
          input_ids (torch.Tensor, *optional*, None): pregenerated caption token ids after tokenization; if provided, caption generation continues the provided text
          attention_mask (torch.Tensor): attention mask for caption tokens, used only if input_ids provided
        Returns:
          generation output (torch.Tensor): tensor which represents sequence of generated caption token ids
        """
        batch_size = pixel_values.shape[0]
        image_embeds = self.vision_model(pixel_values.detach().numpy())[self.vision_model_out]
        image_attention_mask = torch.ones(image_embeds.shape[:-1], dtype=torch.long)
        if isinstance(input_ids, list):
            input_ids = torch.LongTensor(input_ids)
        elif input_ids is None:
            # No prompt provided: start from a BOS token followed by EOS as a placeholder
            input_ids = torch.LongTensor(
                [
                    [
                        self.config.text_config.bos_token_id,
                        self.config.text_config.eos_token_id,
                    ]
                ]
            ).repeat(batch_size, 1)
        input_ids[:, 0] = self.config.text_config.bos_token_id
        attention_mask = attention_mask[:, :-1] if attention_mask is not None else None
        outputs = self.text_decoder.generate(
            input_ids=input_ids[:, :-1],
            eos_token_id=self.config.text_config.sep_token_id,
            pad_token_id=self.config.text_config.pad_token_id,
            attention_mask=attention_mask,
            encoder_hidden_states=torch.from_numpy(image_embeds),
            encoder_attention_mask=image_attention_mask,
            **generate_kwargs,
        )
        return outputs
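

# Example usage (a sketch under assumed names: the processor, the PyTorch BLIP model,
# the compiled OpenVINO model files, and `raw_image` are not defined in this module,
# and the file paths are hypothetical; the text decoder can additionally be patched
# with text_decoder_forward as sketched above to run decoding steps through OpenVINO):
#
#   from transformers import BlipProcessor, BlipForQuestionAnswering
#
#   processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
#   pt_model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base")
#
#   core = ov.Core()
#   ov_vision_model = core.compile_model("blip_vision_model.xml", "CPU")
#   ov_text_encoder = core.compile_model("blip_text_encoder.xml", "CPU")
#
#   ov_model = OVBlipModel(pt_model.config, pt_model.decoder_start_token_id,
#                          ov_vision_model, ov_text_encoder, pt_model.text_decoder)
#
#   inputs = processor(raw_image, "What is on the picture?", return_tensors="pt")
#   answer_ids = ov_model.generate_answer(**inputs, max_length=10)
#   print(processor.decode(answer_ids[0], skip_special_tokens=True))
#
#   caption_ids = ov_model.generate_caption(inputs["pixel_values"], max_length=20)
#   print(processor.decode(caption_ids[0], skip_special_tokens=True))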