# --------------------------------------------------------
# InternVL
# Copyright (c) 2023 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
import warnings
from typing import Any, List, Optional, Tuple, Union

import torch.distributed as dist
import torch.utils.checkpoint
from peft import LoraConfig, get_peft_model
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer
from transformers.generation.logits_process import LogitsProcessorList
from transformers.generation.stopping_criteria import StoppingCriteriaList
from transformers.generation.streamers import BaseStreamer
from transformers.generation.utils import (
    GreedySearchDecoderOnlyOutput, GreedySearchEncoderDecoderOutput,
    GreedySearchOutput, validate_stopping_criteria)
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import ModelOutput, logging

from .configuration_internvl_chat import InternVLChatConfig
from .modeling_intern_vit import InternVisionModel

logger = logging.get_logger(__name__)

# modified from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/utils.py
# Fix bug when using device_map='auto' for distributed inference
class MLlamaForCausalLM(LlamaForCausalLM):

    def greedy_search(
        self,
        input_ids: torch.LongTensor,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[Union[int, List[int]]] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_scores: Optional[bool] = None,
        return_dict_in_generate: Optional[bool] = None,
        synced_gpus: bool = False,
        streamer: Optional["BaseStreamer"] = None,
        **model_kwargs,
    ) -> Union[GreedySearchOutput, torch.LongTensor]:
        # init values
        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
        if max_length is not None:
            warnings.warn(
                "`max_length` is deprecated in this function, use"
                " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.",
                UserWarning,
            )
            stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
        pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]
        eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
        output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
        output_attentions = (
            output_attentions if output_attentions is not None else self.generation_config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
        )
        return_dict_in_generate = (
            return_dict_in_generate
            if return_dict_in_generate is not None
            else self.generation_config.return_dict_in_generate
        )

        # init attention / hidden states / scores tuples
        scores = () if (return_dict_in_generate and output_scores) else None
        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
        cross_attentions = () if (return_dict_in_generate and output_attentions) else None
        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
        if return_dict_in_generate and self.config.is_encoder_decoder:
            encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
            encoder_hidden_states = (
                model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
            )

        # keep track of which sequences are already finished
        unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)

        this_peer_finished = False  # used by synced_gpus only
        while True:
            if synced_gpus:
                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
                # The following logic allows an early break if all peers finished generating their sequence
                this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
                # send 0.0 if we finished, 1.0 otherwise
                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
                # did all peers finish? the reduced sum will be 0.0 then
                if this_peer_finished_flag.item() == 0.0:
                    break

            # prepare model inputs
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

            # forward pass to get next token
            outputs = self(
                **model_inputs,
                return_dict=True,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
            )

            if synced_gpus and this_peer_finished:
                continue  # don't waste resources running the code we don't need

            next_token_logits = outputs.logits[:, -1, :]

            # pre-process distribution
            next_tokens_scores = logits_processor(input_ids, next_token_logits)

            # Store scores, attentions and hidden_states when required
            if return_dict_in_generate:
                if output_scores:
                    scores += (next_tokens_scores,)
                if output_attentions:
                    decoder_attentions += (
                        (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                    )
                    if self.config.is_encoder_decoder:
                        cross_attentions += (outputs.cross_attentions,)
                if output_hidden_states:
                    decoder_hidden_states += (
                        (outputs.decoder_hidden_states,)
                        if self.config.is_encoder_decoder
                        else (outputs.hidden_states,)
                    )

            # argmax
            next_tokens = torch.argmax(next_tokens_scores, dim=-1).to(device=input_ids.device)
            # finished sentences should have their next token be a padding token
            if eos_token_id is not None:
                if pad_token_id is None:
                    raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

            # update generated ids, model inputs, and length for next step
            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
            if streamer is not None:
                streamer.put(next_tokens.cpu())
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )

            # if eos_token was found in one sentence, set sentence to finished
            if eos_token_id_tensor is not None:
                unfinished_sequences = unfinished_sequences.mul(
                    next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
                )

                # stop when each sentence is finished
                if unfinished_sequences.max() == 0:
                    this_peer_finished = True

            # stop if we exceed the maximum length
            if stopping_criteria(input_ids, scores):
                this_peer_finished = True

            if this_peer_finished and not synced_gpus:
                break

        if streamer is not None:
            streamer.end()

        if return_dict_in_generate:
            if self.config.is_encoder_decoder:
                return GreedySearchEncoderDecoderOutput(
                    sequences=input_ids,
                    scores=scores,
                    encoder_attentions=encoder_attentions,
                    encoder_hidden_states=encoder_hidden_states,
                    decoder_attentions=decoder_attentions,
                    cross_attentions=cross_attentions,
                    decoder_hidden_states=decoder_hidden_states,
                    past_key_values=model_kwargs.get("past_key_values"),
                )
            else:
                return GreedySearchDecoderOnlyOutput(
                    sequences=input_ids,
                    scores=scores,
                    attentions=decoder_attentions,
                    hidden_states=decoder_hidden_states,
                    past_key_values=model_kwargs.get("past_key_values"),
                )
        else:
            return input_ids

class InternVLChatModel(PreTrainedModel):
    config_class = InternVLChatConfig
    main_input_name = 'pixel_values'
    _no_split_modules = ['InternVisionModel', 'LlamaDecoderLayer']

    def __init__(self, config: InternVLChatConfig, vision_model=None, language_model=None):
        super().__init__(config)

        image_size = config.force_image_size or config.vision_config.image_size
        patch_size = config.vision_config.patch_size
        self.select_layer = config.select_layer
        self.template = config.template
        self.num_image_token = int((image_size // patch_size) ** 2 * (config.downsample_ratio ** 2))
        self.downsample_ratio = config.downsample_ratio

        logger.info(f'num_image_token: {self.num_image_token}')
        if vision_model is not None:
            self.vision_model = vision_model
        else:
            self.vision_model = InternVisionModel(config.vision_config)
        if language_model is not None:
            self.language_model = language_model
        else:
            # self.language_model = LlamaForCausalLM(config.llm_config)
            self.language_model = MLlamaForCausalLM(config.llm_config)
        vit_hidden_size = config.vision_config.hidden_size
        llm_hidden_size = config.llm_config.hidden_size

        self.mlp1 = nn.Sequential(
            nn.LayerNorm(vit_hidden_size * 4),
            nn.Linear(vit_hidden_size * 4, llm_hidden_size),
            nn.GELU(),
            nn.Linear(llm_hidden_size, llm_hidden_size)
        )

        # Only resize the position embeddings when a different input resolution is
        # explicitly forced; otherwise keep the backbone's native embeddings.
        if config.force_image_size and config.force_image_size != config.vision_config.image_size:
            self.vision_model.resize_pos_embeddings(
                old_size=config.vision_config.image_size,
                new_size=config.force_image_size,
                patch_size=config.vision_config.patch_size
            )

        self.img_context_token_id = None

        if config.use_backbone_lora:
            self.wrap_backbone_lora(r=config.use_backbone_lora)

        if config.use_llm_lora:
            self.wrap_llm_lora(r=config.use_llm_lora)

    def wrap_backbone_lora(self, r=128, lora_alpha=256, lora_dropout=0.05):
        lora_config = LoraConfig(
            r=r,
            target_modules=['attn.qkv', 'attn.proj', 'mlp.fc1', 'mlp.fc2'],
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
        )
        self.vision_model = get_peft_model(self.vision_model, lora_config)
        self.vision_model.print_trainable_parameters()

    def wrap_llm_lora(self, r=128, lora_alpha=256, lora_dropout=0.05):
        lora_config = LoraConfig(
            r=r,
            target_modules=['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.o_proj',
                            'mlp.gate_proj', 'mlp.down_proj', 'mlp.up_proj'],
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            task_type='CAUSAL_LM'
        )
        self.language_model = get_peft_model(self.language_model, lora_config)
        self.language_model.print_trainable_parameters()

    def forward(
        self,
        pixel_values: torch.FloatTensor,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        image_flags: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        image_flags = image_flags.squeeze(-1)
        input_embeds = self.language_model.get_input_embeddings()(input_ids)

        vit_embeds = self.extract_feature(pixel_values)
        vit_embeds = vit_embeds[image_flags == 1]

        B, N, C = input_embeds.shape
        input_embeds = input_embeds.reshape(B * N, C)

        input_ids = input_ids.reshape(B * N)
        selected = (input_ids == self.img_context_token_id)
        try:
            input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds.reshape(-1, C)
        except Exception as e:
            # The number of <IMG_CONTEXT> positions and the number of visual tokens can
            # differ (e.g. after truncation); log the mismatch instead of failing silently
            # so shape problems remain visible, and fall back to the text embeddings.
            logger.warning(f'Failed to scatter visual embeddings into the text sequence: {e}')
        input_embeds = input_embeds.reshape(B, N, C)

        outputs = self.language_model.model(
            inputs_embeds=input_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = outputs[0]
        logits = self.language_model.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.language_model.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def pixel_shuffle(self, x, scale_factor=0.5):
        n, w, h, c = x.size()
        # N, W, H, C --> N, W, H * scale, C // scale
        x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
        # N, W, H * scale, C // scale --> N, H * scale, W, C // scale
        x = x.permute(0, 2, 1, 3).contiguous()
        # N, H * scale, W, C // scale --> N, H * scale, W * scale, C // (scale ** 2)
        x = x.view(n, int(h * scale_factor), int(w * scale_factor),
                   int(c / (scale_factor * scale_factor)))
        return x
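
    # Worked shape example (illustrative numbers): a 448x448 input with 14x14 patches
    # gives a 32x32 grid of patch tokens, and pixel_shuffle with scale_factor=0.5 maps
    # (N, 32, 32, C) -> (N, 16, 16, 4*C): each 2x2 neighbourhood of visual tokens is
    # folded into the channel dimension, quartering the token count. This matches
    # num_image_token = (image_size // patch_size) ** 2 * downsample_ratio ** 2 in
    # __init__ and the vit_hidden_size * 4 input width of mlp1 (which assumes a 0.5 ratio).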

    def extract_feature(self, pixel_values):
        if self.select_layer == -1:
            vit_embeds = self.vision_model(
                pixel_values=pixel_values,
                output_hidden_states=False,
                return_dict=True).last_hidden_state
        else:
            vit_embeds = self.vision_model(
                pixel_values=pixel_values,
                output_hidden_states=True,
                return_dict=True).hidden_states[self.select_layer]
        # drop the [CLS] token, keeping only the patch tokens
        vit_embeds = vit_embeds[:, 1:, :]
        # if torch.distributed.get_rank() == 0:
        #     print("before pixel shuffle:", vit_embeds.shape)
        h = w = int(vit_embeds.shape[1] ** 0.5)
        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
        vit_embeds = self.pixel_shuffle(vit_embeds, scale_factor=self.downsample_ratio)
        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])
        # if torch.distributed.get_rank() == 0:
        #     print("after pixel shuffle:", vit_embeds.shape)
        vit_embeds = self.mlp1(vit_embeds)
        return vit_embeds

    def chat(self, tokenizer, pixel_values, question, generation_config, history=None, return_history=False,
             IMG_START_TOKEN='<img>', IMG_END_TOKEN='</img>', IMG_CONTEXT_TOKEN='<IMG_CONTEXT>'):
        img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
        self.img_context_token_id = img_context_token_id

        from .conversation import get_conv_template

        template = get_conv_template(self.template)

        if history is None:
            history = []
            image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token + IMG_END_TOKEN
            question = image_tokens + '\n' + question
        else:
            for (old_question, old_answer) in history:
                template.append_message(template.roles[0], old_question)
                template.append_message(template.roles[1], old_answer)
        template.append_message(template.roles[0], question)
        template.append_message(template.roles[1], None)
        query = template.get_prompt()
        model_inputs = tokenizer(query, return_tensors='pt')
        input_ids = model_inputs['input_ids'].cuda()
        attention_mask = model_inputs['attention_mask'].cuda()

        generation_output = self.generate(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
            **generation_config
        )
        response = tokenizer.batch_decode(generation_output, skip_special_tokens=True)[0]
        history.append((question, response))
        if return_history:
            return response, history
        else:
            return response
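
    # A typical `generation_config` for `chat()` is a plain dict of standard
    # `transformers` generate() arguments (values below are illustrative):
    #   generation_config = dict(max_new_tokens=512, do_sample=False, num_beams=1)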

    def generate(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        visual_features: Optional[torch.FloatTensor] = None,
        generation_config: Optional[GenerationConfig] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **generate_kwargs,
    ) -> torch.LongTensor:
        assert self.img_context_token_id is not None
        if pixel_values is not None:
            if visual_features is not None:
                vit_embeds = visual_features
            else:
                vit_embeds = self.extract_feature(pixel_values)

            input_embeds = self.language_model.get_input_embeddings()(input_ids)
            B, N, C = input_embeds.shape
            input_embeds = input_embeds.reshape(B * N, C)

            input_ids = input_ids.reshape(B * N)
            selected = (input_ids == self.img_context_token_id)
            assert selected.sum() != 0
            input_embeds[selected] = vit_embeds.reshape(-1, C).to(input_embeds.device)

            input_embeds = input_embeds.reshape(B, N, C)
        else:
            input_embeds = self.language_model.get_input_embeddings()(input_ids)

        outputs = self.language_model.generate(
            inputs_embeds=input_embeds,
            attention_mask=attention_mask,
            generation_config=generation_config,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            use_cache=True,
            **generate_kwargs,
        )

        return outputs
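
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only; not executed on import). The checkpoint id,
# dtype, and image preprocessing below are assumptions -- adapt them to the
# actual released weights and the project's own preprocessing utilities.
#
#   import torch
#   from transformers import AutoModel, AutoTokenizer
#
#   path = 'OpenGVLab/InternVL-Chat-...'  # hypothetical checkpoint id
#   tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
#   model = AutoModel.from_pretrained(
#       path, torch_dtype=torch.bfloat16, trust_remote_code=True).cuda().eval()
#
#   # pixel_values: a (num_images, 3, H, W) float tensor produced by whatever
#   # transform the vision tower expects (resize/normalize not shown here).
#   generation_config = dict(max_new_tokens=512, do_sample=False)
#   response = model.chat(tokenizer, pixel_values, 'Describe the image.',
#                         generation_config)
#   print(response)
# ---------------------------------------------------------------------------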