import io
import logging
from typing import List, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import MSELoss
from torch.cuda.amp import autocast
from transformers.modeling_outputs import CausalLMOutputWithPast

from .modeling_base import BaseMLLM

logger = logging.getLogger(__name__)


class InternVideo2_VideoChat2(BaseMLLM):

    def __init__(self, config):
        super().__init__(config=config)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        image: Optional[torch.Tensor] = None,
        video: Optional[torch.Tensor] = None,
        instruction=None,
        video_idx=None,
        image_idx=None,
    ):
        # Build the multimodal input embeddings. When the vision regression loss is
        # enabled, the raw visual features are also returned (they are not consumed
        # in this forward pass itself).
        if self.use_vision_regression_loss:
            text_embeds, visual, visual_idx = self.pad_text_embeds(
                input_ids=input_ids,
                image=image,
                video=video,
                return_visual=True,
                video_idx=video_idx,
                image_idx=image_idx,
                instruction=instruction,
            )
        else:
            text_embeds = self.pad_text_embeds(
                input_ids=input_ids,
                image=image,
                video=video,
                return_visual=False,
                video_idx=video_idx,
                image_idx=image_idx,
                instruction=instruction,
            )

        # Let the causal LM compute logits (and the loss against `labels` if given).
        outputs = self.lm(
            inputs_embeds=text_embeds,
            attention_mask=attention_mask,
            labels=labels,
            output_hidden_states=True,
            return_dict=True,
        )

        return outputs

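    # pad_text_embeds builds the multimodal input sequence for the language model:
    # it embeds input_ids with the (detached) LM embedding table, encodes the image
    # or video with encode_vision, projects the result into the LM embedding space
    # with project_up, and overwrites the positions where image_idx / video_idx == 1.
    # The number of 1-entries in that mask must therefore equal the number of
    # flattened visual tokens produced for the batch.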
    def pad_text_embeds(
        self,
        input_ids: torch.LongTensor = None,
        image: Optional[torch.Tensor] = None,
        video: Optional[torch.Tensor] = None,
        image_idx=None,
        video_idx=None,
        return_visual: bool = False,
        instruction=None,
    ):
        # Embed the text tokens with the LM embedding table; detach so this path
        # sends no gradient back into the embedding table.
        text_embeds = self.lm.get_input_embeddings()(input_ids.long()).detach()

        visual = None
        visual_idx = None

        if image is not None:
            B, T, C, H, W = image.shape
            image = image.permute(0, 2, 1, 3, 4)  # (B, C, T, H, W) for the vision encoder
            prompt_image_embeds = self.encode_vision(image, instruction=instruction)
            visual = prompt_image_embeds
            prompt_image_embeds = self.project_up(prompt_image_embeds)
            prompt_image_embeds = prompt_image_embeds.view(-1, prompt_image_embeds.shape[-1])
            visual_idx = image_idx
            # Overwrite the placeholder positions (image_idx == 1) with the projected visual tokens.
            text_embeds[image_idx == 1] = text_embeds[image_idx == 1] * 0 + prompt_image_embeds.to(text_embeds.device)
        elif video is not None:
            if len(video.shape) == 5:
                B, T, C, H, W = video.shape
                N = 1
            else:
                B, N, T, C, H, W = video.shape
            video = video.reshape(B * N, T, C, H, W).permute(0, 2, 1, 3, 4)
            prompt_video_embeds = self.encode_vision(video, instruction=instruction)
            visual = prompt_video_embeds
            prompt_video_embeds = self.project_up(prompt_video_embeds)
            prompt_video_embeds = prompt_video_embeds.view(-1, prompt_video_embeds.shape[-1])
            visual_idx = video_idx
            # Overwrite the placeholder positions (video_idx == 1) with the projected visual tokens.
            text_embeds[video_idx == 1] = text_embeds[video_idx == 1] * 0 + prompt_video_embeds.to(text_embeds.device).to(text_embeds.dtype)
        else:
            logger.warning(f"No visual input received, input_ids: {input_ids}")

        if return_visual:
            return text_embeds, visual, visual_idx

        return text_embeds

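    # encode_vision compresses a clip of shape (B, C, T, H, W) into a fixed number of
    # visual tokens: vision encoder -> layernorm -> Q-Former cross-attention with
    # learned query tokens, optionally conditioned on the text instruction. A single
    # frame (T == 1) is treated as an image by the vision encoder.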
    def encode_vision(
        self,
        image,
        instruction,
    ):
        device = image.device
        B = image.shape[0]
        T = image.shape[2]
        use_image = (T == 1)
        image_embeds = self.vision_encoder(image, use_image=use_image)
        C = image_embeds.shape[-1]
        image_embeds = image_embeds.reshape(B, -1, C)
        image_embeds = self.vision_layernorm(image_embeds).to(device)

        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device)
        if self.extra_num_query_token > 0:
            query_tokens = torch.cat([self.query_tokens, self.extra_query_tokens], dim=1)
        else:
            # Fall back to the base query tokens when no extra query tokens are configured.
            query_tokens = self.query_tokens
        query_tokens = query_tokens.expand(image_embeds.shape[0], -1, -1)
        if instruction is not None:
            text_Qformer = self.qformer_tokenizer(
                instruction,
                padding='longest',
                truncation=True,
                max_length=512,
                return_tensors="pt",
            ).to(image_embeds.device)
            query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(image_embeds.device)
            Qformer_atts = torch.cat([query_atts, text_Qformer.attention_mask], dim=1)
            query_output = self.qformer.bert(
                text_Qformer.input_ids,
                attention_mask=Qformer_atts,
                query_embeds=query_tokens,
                encoder_hidden_states=image_embeds,
                encoder_attention_mask=image_atts,
                return_dict=True,
            )
        else:
            query_output = self.qformer.bert(
                query_embeds=query_tokens,
                encoder_hidden_states=image_embeds,
                encoder_attention_mask=image_atts,
                return_dict=True,
            )

        # Keep only the query-token states (drop the instruction-token states when present).
        return query_output.last_hidden_state[:, :query_tokens.size(1), :]

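    # generate_caption is the inference entry point: it splices the visual embeddings
    # into the prompt via pad_text_embeds and delegates decoding to the language
    # model's generate() with the sampling / beam-search arguments below.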
    def generate_caption(
        self,
        input_ids,
        attention_mask,
        image_idx=None,
        video_idx=None,
        image: Optional[torch.Tensor] = None,
        video: Optional[torch.Tensor] = None,
        num_beams=1,
        max_new_tokens=200,
        do_sample=True,
        top_p=0.9,
        top_k=None,
        temperature=1.0,
        length_penalty=1.0,
        repetition_penalty=1.0,
    ):
        text_embeds = self.pad_text_embeds(
            input_ids=input_ids,
            image=image,
            video=video,
            image_idx=image_idx,
            video_idx=video_idx,
        )
        outputs = self.lm.generate(
            inputs_embeds=text_embeds,
            attention_mask=attention_mask,
            num_beams=num_beams,
            max_new_tokens=max_new_tokens,
            do_sample=do_sample,
            min_length=1,
            top_p=top_p,
            top_k=top_k,
            temperature=temperature,
            length_penalty=length_penalty,
            repetition_penalty=repetition_penalty,
        )

        return outputs
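

# Minimal usage sketch (not part of the model definition). The repo id, config
# loading, and tokenizer below are assumptions for illustration; the actual
# checkpoint packaging and video preprocessing pipeline may differ.
if __name__ == "__main__":
    from transformers import AutoConfig, AutoTokenizer

    repo = "OpenGVLab/InternVideo2-Chat-8B"  # hypothetical checkpoint id
    config = AutoConfig.from_pretrained(repo, trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)
    model = InternVideo2_VideoChat2(config).eval()

    # Text-only call: with neither `image` nor `video`, pad_text_embeds just embeds
    # the prompt (and logs a warning), so this degrades to plain LM decoding.
    enc = tokenizer("Describe the video.", return_tensors="pt")
    with torch.no_grad():
        out = model.generate_caption(
            input_ids=enc.input_ids,
            attention_mask=enc.attention_mask,
        )
    print(tokenizer.decode(out[0], skip_special_tokens=True))

    # For a real video query, also pass `video` of shape (B, T, C, H, W) and a 0/1
    # `video_idx` mask over input_ids whose 1-entries mark where the projected
    # visual tokens should be spliced in.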