|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
from einops import rearrange |
|
from transformers import ( |
|
AutoConfig, |
|
AutoModelForCausalLM, |
|
LlamaConfig, |
|
LlamaForCausalLM, |
|
PreTrainedModel, |
|
GenerationMixin |
|
) |
|
from transformers.configuration_utils import PretrainedConfig |
|
|
|
from .clip_encoder import CLIPVisionTower |
|
from .siglip_vit import create_siglip_vit |
|
from .projector import MlpProjector |
|
from .configuration_vlm import AttrDict, MultiModalityConfig, VisionConfig, AlignerConfig, GenVisionConfig, GenHeadConfig, GenAlignerConfig |
|
|
|
|
|
class vision_head(torch.nn.Module):
    """Two-layer projection head: Linear -> GELU -> Linear.

    Maps features of width ``params.n_embed`` to ``params.image_token_size``
    outputs through a hidden width of ``params.image_token_embed``.

    Args:
        params: object exposing ``n_embed``, ``image_token_embed`` and
            ``image_token_size`` attributes.
    """

    def __init__(self, params):
        super().__init__()
        in_width = params.n_embed
        hidden_width = params.image_token_embed
        out_width = params.image_token_size

        # The attribute names below determine the state_dict keys, so they
        # must stay exactly as they are for checkpoint compatibility.
        self.output_mlp_projector = torch.nn.Linear(in_width, hidden_width)
        self.vision_activation = torch.nn.GELU()
        self.vision_head = torch.nn.Linear(hidden_width, out_width)

    def forward(self, x):
        """Apply projector, activation and output layer to ``x``."""
        hidden = self.vision_activation(self.output_mlp_projector(x))
        return self.vision_head(hidden)
|
|
|
|
|
def model_name_to_cls(cls_name):
    """Resolve a config ``cls`` string to the class or callable it names.

    Matching is by substring, tested in this order:
    ``"MlpProjector"``, ``"CLIPVisionTower"``, ``"VQ"`` (resolved by the full
    name through ``VQ_models``), ``"vision_head"``.

    Args:
        cls_name: name string taken from a sub-config's ``cls`` field.

    Returns:
        The matching class, or the ``VQ_models`` entry for VQ names.

    Raises:
        ValueError: if ``cls_name`` matches none of the known names.
    """
    if "MlpProjector" in cls_name:
        return MlpProjector
    if "CLIPVisionTower" in cls_name:
        return CLIPVisionTower
    if "VQ" in cls_name:
        # Imported lazily so the VQ module is only loaded when requested.
        from janus.models.vq_model import VQ_models

        return VQ_models[cls_name]
    if "vision_head" in cls_name:
        return vision_head
    raise ValueError(f"class_name {cls_name} is invalid.")
|
|
|
|
|
class MultiModalityPreTrainedModel(PreTrainedModel):
    """``PreTrainedModel`` base bound to ``MultiModalityConfig``.

    Only sets the class-level hooks that ``PreTrainedModel`` reads; it adds
    no modules or methods of its own.
    """

    # Config class and weight-name prefix used by PreTrainedModel.
    config_class = MultiModalityConfig
    base_model_prefix = "multi_modality"

    # Device-placement related hooks.
    _skip_keys_device_placement = "past_key_values"
    _no_split_modules = []
|
|
|
|
|
class MultiModalityCausalLM(MultiModalityPreTrainedModel):
    """Multimodal causal LM built around a ``LlamaForCausalLM``.

    Submodules are built from ``config`` through ``model_name_to_cls``:

    * ``vision_model`` + ``aligner`` turn ``pixel_values`` into embeddings
      that are spliced into the language model's input embeddings.
    * ``gen_vision_model``, ``gen_aligner``, ``gen_head`` and ``gen_embed``
      are constructed here; of these, only ``gen_embed`` and ``gen_aligner``
      are used in this class (by ``prepare_gen_img_embeds``).
    * ``language_model`` is a ``LlamaForCausalLM`` built from
      ``config.language_config``.
    """

    def __init__(self, config: MultiModalityConfig):
        super().__init__(config)

        # NOTE: submodule construction order is kept stable, since it fixes
        # the order in which parameters are created and registered.
        vision_config = config.vision_config
        vision_cls = model_name_to_cls(vision_config.cls)
        # The vision tower receives its params as keyword arguments; the
        # other submodules below receive the params object itself.
        self.vision_model = vision_cls(**vision_config.params)

        aligner_config = config.aligner_config
        aligner_cls = model_name_to_cls(aligner_config.cls)
        self.aligner = aligner_cls(aligner_config.params)

        gen_vision_config = config.gen_vision_config
        gen_vision_cls = model_name_to_cls(gen_vision_config.cls)
        self.gen_vision_model = gen_vision_cls()

        gen_aligner_config = config.gen_aligner_config
        gen_aligner_cls = model_name_to_cls(gen_aligner_config.cls)
        self.gen_aligner = gen_aligner_cls(gen_aligner_config.params)

        gen_head_config = config.gen_head_config
        gen_head_cls = model_name_to_cls(gen_head_config.cls)
        self.gen_head = gen_head_cls(gen_head_config.params)

        # Lookup table for generation image-token ids
        # (image_token_size rows of width n_embed).
        self.gen_embed = torch.nn.Embedding(
            gen_vision_config.params.image_token_size,
            gen_vision_config.params.n_embed,
        )

        language_config = config.language_config
        self.language_model = LlamaForCausalLM(language_config)

    def prepare_inputs_embeds(
        self,
        input_ids: torch.LongTensor,
        pixel_values: torch.FloatTensor,
        images_seq_mask: torch.BoolTensor,
        images_emb_mask: torch.BoolTensor,
        **kwargs,
    ):
        """Embed ``input_ids`` and overwrite image positions with image features.

        Args:
            input_ids (torch.LongTensor): [b, T]. Negative ids are looked up
                as id 0. The caller's tensor is NOT modified.
            pixel_values (torch.FloatTensor): [b, n_images, 3, h, w], or
                ``None`` for text-only input (no image features are spliced).
            images_seq_mask (torch.BoolTensor): [b, T]; True at the sequence
                positions that receive image features.
            images_emb_mask (torch.BoolTensor): [b, n_images, n_image_tokens];
                True for the image-feature rows to use.
            **kwargs: accepted and ignored.

        The number of True entries in ``images_seq_mask`` must equal the
        number of True entries in ``images_emb_mask``.

        Returns:
            inputs_embeds (torch.Tensor): [b, T, D]
        """
        # Out-of-place replacement of negative ids: the previous in-place
        # ``input_ids[input_ids < 0] = 0`` silently rewrote the caller's tensor.
        input_ids = input_ids.clamp(min=0)
        inputs_embeds = self.language_model.get_input_embeddings()(input_ids)

        if pixel_values is None:
            # Text-only input: nothing to splice in.
            return inputs_embeds

        bs, n = pixel_values.shape[0:2]
        images = rearrange(pixel_values, "b n c h w -> (b n) c h w")
        # Encode every image, then project into the language model's space.
        images_embeds = self.aligner(self.vision_model(images))

        # Lay all images of a sample out back-to-back along the token axis,
        # and flatten the embedding mask the same way so the two line up.
        images_embeds = rearrange(images_embeds, "(b n) t d -> b (n t) d", b=bs, n=n)
        images_emb_mask = rearrange(images_emb_mask, "b n t -> b (n t)")

        # Boolean-mask assignment requires matching dtypes, so cast the image
        # features to the text-embedding dtype (a no-op when they already match).
        inputs_embeds[images_seq_mask] = images_embeds[images_emb_mask].to(
            inputs_embeds.dtype
        )

        return inputs_embeds

    def prepare_gen_img_embeds(self, image_ids: torch.LongTensor):
        """Embed generation image-token ids and pass them through ``gen_aligner``."""
        return self.gen_aligner(self.gen_embed(image_ids))

    def forward(
        self,
        input_ids,
        pixel_values=None,
        past_key_values=None,
        inputs_embeds=None,
        attention_mask=None,
        position_ids=None,
        images_seq_mask=None,
        images_emb_mask=None,
        **kwargs,
    ):
        """Run the language model on (multimodal) input embeddings.

        When ``inputs_embeds`` is not given it is built with
        ``prepare_inputs_embeds``. The language model is then called with
        ``input_ids=None`` so only the embeddings are used. Extra ``kwargs``
        are passed to both ``prepare_inputs_embeds`` (which ignores them) and
        ``self.language_model.forward``.

        Returns:
            Whatever ``self.language_model.forward`` returns.
        """
        if inputs_embeds is None:
            inputs_embeds = self.prepare_inputs_embeds(
                input_ids, pixel_values, images_seq_mask, images_emb_mask, **kwargs
            )
        return self.language_model.forward(
            input_ids=None,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            **kwargs,
        )

    def generate(
        self,
        input_ids=None,
        pixel_values=None,
        past_key_values=None,
        inputs_embeds=None,
        attention_mask=None,
        position_ids=None,
        images_seq_mask=None,
        images_emb_mask=None,
        **kwargs
    ):
        """Generate with the language model from (multimodal) input embeddings.

        Builds ``inputs_embeds`` with ``prepare_inputs_embeds`` when not given,
        then calls ``self.language_model.generate`` with embeddings only
        (``input_ids`` is not forwarded). Extra ``kwargs`` go to both calls.

        Returns:
            Whatever ``self.language_model.generate`` returns.
        """
        if inputs_embeds is None:
            inputs_embeds = self.prepare_inputs_embeds(
                input_ids, pixel_values, images_seq_mask, images_emb_mask, **kwargs
            )
        return self.language_model.generate(
            inputs_embeds=inputs_embeds,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            position_ids=position_ids,
            **kwargs,
        )
|
|
|
|
|
|
|
|
|
# Expose the sub-configs, the top-level config and the model through the
# transformers Auto* factories (registration order is unchanged).
for _model_type, _config_cls in (
    ("vision", VisionConfig),
    ("aligner", AlignerConfig),
    ("gen_vision", GenVisionConfig),
    ("gen_aligner", GenAlignerConfig),
    ("gen_head", GenHeadConfig),
    ("multi_modality", MultiModalityConfig),
):
    AutoConfig.register(_model_type, _config_cls)
# Keep the module namespace identical to the unrolled form.
del _model_type, _config_cls

AutoModelForCausalLM.register(MultiModalityConfig, MultiModalityCausalLM)
|
|