from pathlib import Path
from typing import Optional, Tuple, Union, List

import openvino as ov
import numpy as np
import torch
from transformers import AutoConfig
from transformers.generation import GenerationConfig, GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithPast


core = ov.Core()


LANGUAGE_MODEL_NAME = "openvino_language_model.xml"
VISION_TOWER_HIGH_NAME = "openvino_vision_tower_high_model.xml"
TEXT_EMBEDDING_NAME = "openvino_text_embeddings_model.xml"
PROJECTOR_VARY_NAME = "openvino_projector_vary_model.xml"
LM_HEAD_NAME = "openvino_lm_head_model.xml"

class OvModelForCausalLMWithEmb(GenerationMixin):
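    """GenerationMixin wrapper around the stateful OpenVINO language model.

    The KV cache lives inside the compiled model's infer request, so the
    `past_key_values` returned to `generate()` is only a placeholder flag:
    `request.reset_state()` clears the cache at the start of a new sequence,
    `_past_length` tracks how many positions have been processed, and the
    `beam_idx` input tells the model how to reorder its internal cache during
    beam search (see `_reorder_cache`).
    """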
    def __init__(self, model_dir, device="CPU", config=None, ov_config=None, compile=True) -> None:
        self._supports_cache_class = False
        self.config = AutoConfig.from_pretrained(model_dir) if config is None else config
        self.config.is_decoder = True
        self.config.is_encoder_decoder = False
        self.generation_config = GenerationConfig.from_model_config(self.config)
        model_dir = Path(model_dir)
        self.model = core.read_model(model_dir / LANGUAGE_MODEL_NAME)
        self.token_emb = core.read_model(model_dir / TEXT_EMBEDDING_NAME)
        self.request = None
        self.token_emb_request = None
        self._device = device.upper()
        self.device = torch.device("cpu")
        self.ov_config = ov_config
        self.next_beam_idx = None
        self._past_length = None
        self.input_names = [input_t.get_any_name() for input_t in self.model.inputs]
        self.main_input_name = "input_ids"
        if compile:
            self.compile()

    def compile(self):
        if self.request is None:
            self.request = core.compile_model(self.model, self._device, self.ov_config).create_infer_request()
        self._compile_token_emb()

    def _compile_token_emb(self):
        if self.token_emb_request is None:
            self.token_emb_request = core.compile_model(self.token_emb, self._device, self.ov_config)

    def to(self, device: str):
        if isinstance(device, str):
            self._device = device.upper()
            self.clear_requests()

        return self

    def clear_requests(self):
        del self.request
        del self.token_emb_request
        self.request = None
        self.token_emb_request = None

    def embed_tokens(self, input_ids: torch.LongTensor):
        self._compile_token_emb()
        res = self.token_emb_request(input_ids, share_inputs=True)
        return res[0]

    def prepare_inputs(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        **kwargs,
    ):
        batch_size = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0]

        inputs = {}

        if past_key_values is None:
            if self.request is not None:
                self.request.reset_state()

            self.next_beam_idx = np.arange(batch_size, dtype=int)
            self._past_length = 0
        past_len = self._get_past_length(past_key_values)

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids if past_key_values is None else input_ids[:, -1:])

            if hasattr(self.config, "scale_emb"):
                inputs_embeds = inputs_embeds * self.config.scale_emb
        inputs["inputs_embeds"] = inputs_embeds

        if "attention_mask" in self.input_names or "position_ids" in self.input_names:
            if attention_mask is not None:
                attention_mask = np.array(attention_mask)
            else:
                attention_mask = np.ones((inputs_embeds.shape[0], inputs_embeds.shape[1] + past_len), dtype=int)

        if "attention_mask" in self.input_names:
            inputs["attention_mask"] = attention_mask

        if "position_ids" in self.input_names:
            if position_ids is not None:
                position_ids = np.array(position_ids)
            else:
                position_ids = np.cumsum(attention_mask, axis=1) - 1
                position_ids[attention_mask == 0] = 1
                if past_key_values:
                    position_ids = position_ids[:, -input_ids.shape[1] :]

            inputs["position_ids"] = position_ids
        if "beam_idx" in self.input_names:
            inputs["beam_idx"] = self.next_beam_idx if self.next_beam_idx is not None else np.arange(batch_size, dtype=int)

        return inputs

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        **kwargs,
    ):
        self.compile()

        inputs = self.prepare_inputs(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            **kwargs,
        )

        self.request.start_async(inputs, share_inputs=True)
        self.request.wait()
        logits = self.request.get_tensor("logits").data
        logits = torch.from_numpy(logits).to(self.device)
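
        # The real cache stays inside the infer request; a non-empty placeholder is
        # returned so that `generate()` keeps passing `past_key_values` back, while
        # `_past_length` records how many positions have been consumed so far.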
        past_key_values = ((),)
        self._past_length += inputs["inputs_embeds"].shape[1]

        return CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values)

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
        attention_mask = kwargs.get("attention_mask", None)
        use_cache = kwargs.get("use_cache", None)

        if past_key_values is not None:
            past_len = self._get_past_length(past_key_values)
            # Keep only the tokens that have not been processed yet.
            if attention_mask is not None and input_ids is not None and attention_mask.shape[1] > input_ids.shape[1]:
                input_ids = input_ids[:, -(attention_mask.shape[1] - past_len) :]
            elif input_ids is not None and past_len < input_ids.shape[1]:
                input_ids = input_ids[:, past_len:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None and "position_ids" in self.input_names:
            # Create position_ids on the fly for batch generation.
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values and input_ids is not None:
                position_ids = position_ids[:, -input_ids.shape[1] :]

        model_inputs = {
            "input_ids": input_ids,
            "past_key_values": past_key_values,
            "use_cache": use_cache,
            "position_ids": position_ids,
            "attention_mask": attention_mask,
            "inputs_embeds": inputs_embeds if past_key_values is None else None,
        }

        return model_inputs

    def _get_past_length(self, past_key_values=None):
        if past_key_values is None:
            return 0
        return self._past_length

    def _reorder_cache(self, past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
        """
        This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
        [`~PreTrainedModel.beam_sample`] is called.
        This is required to match `past_key_values` with the correct beam_idx at every generation step.
        """
        self.next_beam_idx = np.array(beam_idx)
        return past_key_values

    def can_generate(self):
        """Returns True to validate the check that the model using `GenerationMixin.generate()` can indeed generate."""
        return True

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)


class OVGotOcrModel(GenerationMixin):
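    """GenerationMixin wrapper that stitches the exported GOT-OCR2 sub-models together.

    On the prefill step the high-resolution vision tower encodes each image (or each
    crop of a multi-crop input), the projector maps the CNN features into the language
    model embedding space, and the projected features are spliced into the text
    embeddings between the image start/end tokens. The language model then produces
    hidden states, and a separate LM head turns them into logits.
    """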
    def __init__(self, model_dir, device, ov_config=None, compression_configuration=None):
        model_dir = Path(model_dir)
        self.config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
        self.generation_config = GenerationConfig.from_model_config(self.config)
        self.vision_tower_high = core.compile_model(model_dir / VISION_TOWER_HIGH_NAME, device, ov_config)
        self.mm_projector_vary = core.compile_model(model_dir / PROJECTOR_VARY_NAME, device, ov_config)
        self.embed_tokens = core.compile_model(model_dir / TEXT_EMBEDDING_NAME, device)
        self.lm_head = core.compile_model(model_dir / LM_HEAD_NAME, device)
        self.language_model = OvModelForCausalLMWithEmb(model_dir, device, self.config, ov_config)
        self.main_input_name = "input_ids"
        self.device = torch.device("cpu")
        self._supports_cache_class = False
        self.next_beam_idx = None
        self._past_length = None
        self.first = True
        self.im_start_token = self.config.im_start_token

    def can_generate(self):
        """Returns True to validate the check that the model using `GenerationMixin.generate()` can indeed generate."""
        return True

    def __call__(self, *args, **kwargs) -> CausalLMOutputWithPast:
        return self.forward(*args, **kwargs)

    def _reorder_cache(self, *args, **kwargs) -> Tuple[Tuple[torch.Tensor]]:
        """
        This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
        [`~PreTrainedModel.beam_sample`] is called.
        This is required to match `past_key_values` with the correct beam_idx at every generation step.
        """
        return self.language_model._reorder_cache(*args, **kwargs)

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        if past_key_values is not None:
            cache_length = past_length = self.language_model._get_past_length(past_key_values)
            max_cache_length = None

            # Keep only the tokens that have not been processed yet.
            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
            elif past_length < input_ids.shape[1]:
                input_ids = input_ids[:, past_length:]

            # If we are about to exceed the maximum cache length, crop the attention mask.
            if (
                max_cache_length is not None
                and attention_mask is not None
                and cache_length + input_ids.shape[1] > max_cache_length
            ):
                attention_mask = attention_mask[:, -max_cache_length:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # Create position_ids on the fly for batch generation.
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1] :]

        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
                "images": kwargs.get("images", None),
            }
        )
        return model_inputs

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        if inputs_embeds is None:
            inputs_embeds = torch.from_numpy(self.language_model.embed_tokens(input_ids))
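
        # Image features are injected only on the prefill pass (more than one input
        # token) and when images are provided; decode steps reuse the cached context.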
        if self.vision_tower_high is not None and (input_ids.shape[1] != 1) and images is not None:
            use_im_start_end = getattr(self.config, "use_im_start_end", -1)

            vision_select_layer = getattr(self.config, "vision_select_layer", -1)
            im_patch_token = getattr(self.config, "im_patch_token", -1)
            im_start_token = getattr(self.config, "im_start_token", -1)
            im_end_token = getattr(self.config, "im_end_token", -1)
            freeze_vision_tower = getattr(self.config, "freeze_vision_tower", False)
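
            # The special token ids are hard-coded below, overriding the config lookups
            # above; these appear to be the GOT-OCR2 image patch/start/end token ids.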
            im_patch_token = 151859
            im_start_token = 151857
            im_end_token = 151858

            image_features = []
            for image in images:
                P, C, H, W = image.shape
                if P == 1:
                    with torch.set_grad_enabled(False):
                        cnn_feature = self.vision_tower_high(image)[0]
                        cnn_feature = torch.from_numpy(cnn_feature).flatten(2).permute(0, 2, 1).numpy()
                        image_feature = self.mm_projector_vary(cnn_feature)[0]
                    image_features.append(torch.from_numpy(image_feature))
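
                # Multi-crop input: encode every crop separately and concatenate the
                # projected features along the sequence dimension.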
                else:
                    image_patches = torch.unbind(image)
                    image_patches_features = []
                    for image_patch in image_patches:
                        image_p = torch.stack([image_patch])

                        with torch.set_grad_enabled(False):
                            cnn_feature_p = self.vision_tower_high(image_p)[0]
                            cnn_feature_p = torch.from_numpy(cnn_feature_p).flatten(2).permute(0, 2, 1).numpy()
                            image_feature_p = self.mm_projector_vary(cnn_feature_p)[0]
                        image_patches_features.append(torch.from_numpy(image_feature_p))
                    image_feature = torch.cat(image_patches_features, dim=1)
                    image_features.append(image_feature)

            dummy_image_features_2 = torch.zeros(256, 1024, device=inputs_embeds.device, dtype=inputs_embeds.dtype)
            dummy_image_features = dummy_image_features_2
            use_im_start_end = True
            new_input_embeds = []
            for cur_input_ids, cur_input_embeds, cur_image_features in zip(input_ids, inputs_embeds, image_features):
                if (cur_input_ids == im_patch_token).sum() == 0:
                    cur_input_embeds = cur_input_embeds + (0.0 * dummy_image_features).sum()
                    new_input_embeds.append(cur_input_embeds)
                    continue

                if use_im_start_end:
                    if (cur_input_ids == im_start_token).sum() != (cur_input_ids == im_end_token).sum():
                        raise ValueError("The number of image start tokens and image end tokens should be the same.")
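
                    # Splice the projected image features into the text embedding
                    # sequence between each image start/end token pair.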
                    image_start_tokens = torch.where(cur_input_ids == im_start_token)[0]
                    for image_start_token_pos, per_cur_image_features in zip(image_start_tokens, cur_image_features):
                        per_cur_image_features = per_cur_image_features.to(device=cur_input_embeds.device)
                        num_patches = per_cur_image_features.shape[0]

                        if cur_input_ids[image_start_token_pos + num_patches + 1] != im_end_token:
                            raise ValueError("The image end token should follow the image start token.")

                        cur_input_embeds = torch.cat(
                            (
                                cur_input_embeds[: image_start_token_pos + 1],
                                per_cur_image_features,
                                cur_input_embeds[image_start_token_pos + num_patches + 1 :],
                            ),
                            dim=0,
                        )

                    new_input_embeds.append(cur_input_embeds)
                else:
                    raise NotImplementedError

            inputs_embeds = torch.stack(new_input_embeds, dim=0)

        outputs = self.language_model(
            None,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=True,
        )
        logits = outputs[0]
        logits = self.lm_head(logits[0])[0]
        logits = torch.from_numpy(logits).to(self.device)
        logits = logits.unsqueeze(0)

        return CausalLMOutputWithPast(
            loss=None,
            logits=logits,
            past_key_values=outputs.past_key_values,
        )
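

# A minimal usage sketch (not executed here). Assumptions: `model_dir` points to a
# directory containing the exported IR files named above plus the model config, and
# `input_ids` / `images` come from the GOT-OCR2 preprocessing used with this wrapper.
#
#     model = OVGotOcrModel(model_dir, device="CPU")
#     generated_ids = model.generate(input_ids, images=[image_tensor], max_new_tokens=256)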