|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from attrdict import AttrDict |
|
from einops import rearrange |
|
import torch |
|
from transformers.configuration_utils import PretrainedConfig |
|
from transformers import ( |
|
AutoConfig, |
|
AutoModelForCausalLM, |
|
PreTrainedModel, |
|
LlamaConfig, |
|
LlamaForCausalLM, |
|
) |
|
from transformers.models.llama.modeling_llama import LlamaRMSNorm |
|
from janus.janusflow.models.clip_encoder import CLIPVisionTower |
|
from janus.janusflow.models.uvit import ShallowUViTEncoder, ShallowUViTDecoder |
|
import torch.nn as nn |
|
|
|
|
|
def model_name_to_cls(cls_name):
    """Resolve a vision sub-module class from the name stored in a config.

    Matching is by substring, so any string that contains one of the known
    class names (for example a dotted import path) resolves to that class.
    Candidates are checked in this order: ``CLIPVisionTower``,
    ``ShallowUViTEncoder``, ``ShallowUViTDecoder``.

    Args:
        cls_name: The ``cls`` string from a vision sub-config.

    Returns:
        The matching class object.

    Raises:
        ValueError: If none of the known class names occurs in ``cls_name``.
    """
    if "CLIPVisionTower" in cls_name:
        return CLIPVisionTower
    if "ShallowUViTEncoder" in cls_name:
        return ShallowUViTEncoder
    if "ShallowUViTDecoder" in cls_name:
        return ShallowUViTDecoder
    raise ValueError(f"class_name {cls_name} is invalid.")
|
|
|
|
|
class VisionUnderstandEncoderConfig(PretrainedConfig):
    """Config for the image-understanding encoder sub-module.

    Holds the encoder's class name (``cls``) and the keyword arguments used
    to construct it (``params``). ``MultiModalityCausalLM`` resolves ``cls``
    via ``model_name_to_cls`` and instantiates it with ``**params``.
    """

    model_type = "vision_und_enc"
    cls: str = ""
    params: AttrDict = {}

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # ``cls`` may be supplied as a class object; only its name is kept
        # (presumably so the config stays JSON-serializable).
        raw_cls = kwargs.get("cls", "")
        self.cls = raw_cls if isinstance(raw_cls, str) else raw_cls.__name__

        # Wrap constructor kwargs for attribute-style access.
        self.params = AttrDict(kwargs.get("params", {}))
|
|
|
|
|
class VisionGenerationEncoderConfig(PretrainedConfig):
    """Config for the image-generation encoder sub-module.

    Holds the encoder's class name (``cls``) and the keyword arguments used
    to construct it (``params``). ``MultiModalityCausalLM`` resolves ``cls``
    via ``model_name_to_cls`` and instantiates it with ``**params``.
    """

    model_type = "vision_gen_enc"
    cls: str = ""
    params: AttrDict = {}

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # ``cls`` may be supplied as a class object; only its name is kept
        # (presumably so the config stays JSON-serializable).
        raw_cls = kwargs.get("cls", "")
        self.cls = raw_cls if isinstance(raw_cls, str) else raw_cls.__name__

        # Wrap constructor kwargs for attribute-style access.
        self.params = AttrDict(kwargs.get("params", {}))
|
|
|
|
|
class VisionGenerationDecoderConfig(PretrainedConfig):
    """Config for the image-generation decoder sub-module.

    Holds the decoder's class name (``cls``) and the keyword arguments used
    to construct it (``params``). ``MultiModalityCausalLM`` resolves ``cls``
    via ``model_name_to_cls`` and instantiates it with ``**params``.
    """

    model_type = "vision_gen_dec"
    cls: str = ""
    params: AttrDict = {}

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # ``cls`` may be supplied as a class object; only its name is kept
        # (presumably so the config stays JSON-serializable).
        raw_cls = kwargs.get("cls", "")
        self.cls = raw_cls if isinstance(raw_cls, str) else raw_cls.__name__

        # Wrap constructor kwargs for attribute-style access.
        self.params = AttrDict(kwargs.get("params", {}))
|
|
|
|
|
class MultiModalityConfig(PretrainedConfig):
    """Top-level config bundling the vision and language sub-configs.

    Each sub-config (``vision_und_enc_config``, ``vision_gen_enc_config``,
    ``vision_gen_dec_config``, ``language_config``) may be given either as a
    plain ``dict`` of keyword arguments (e.g. as loaded from JSON) or as an
    already-built config object of the matching class. A missing or ``None``
    entry yields a default-constructed sub-config.
    """

    model_type = "multi_modality"

    vision_und_enc_config: VisionUnderstandEncoderConfig
    vision_gen_enc_config: VisionGenerationEncoderConfig
    vision_gen_dec_config: VisionGenerationDecoderConfig
    language_config: LlamaConfig

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # super().__init__ may already have stored the raw values as
        # attributes; they are replaced here with typed config objects.
        self.vision_und_enc_config = self._build_sub_config(
            kwargs.get("vision_und_enc_config"), VisionUnderstandEncoderConfig
        )
        self.vision_gen_enc_config = self._build_sub_config(
            kwargs.get("vision_gen_enc_config"), VisionGenerationEncoderConfig
        )
        self.vision_gen_dec_config = self._build_sub_config(
            kwargs.get("vision_gen_dec_config"), VisionGenerationDecoderConfig
        )
        self.language_config = self._build_sub_config(
            kwargs.get("language_config"), LlamaConfig
        )

    @staticmethod
    def _build_sub_config(value, config_cls):
        """Coerce ``value`` into an instance of ``config_cls``.

        Args:
            value: An existing ``config_cls`` instance (returned unchanged),
                a mapping of constructor kwargs, or ``None`` (treated as ``{}``).
            config_cls: The sub-config class to build.

        Returns:
            A ``config_cls`` instance.
        """
        if isinstance(value, config_cls):
            return value
        if value is None:
            value = {}
        return config_cls(**value)
|
|
|
|
|
class MultiModalityPreTrainedModel(PreTrainedModel):
    """Base class that ties ``MultiModalityConfig`` into the Hugging Face
    ``PreTrainedModel`` loading/saving machinery."""

    # Hugging Face hooks: the config type to (de)serialize and the prefix
    # under which the base model's weights are stored.
    base_model_prefix = "multi_modality"
    config_class = MultiModalityConfig

    # Device-placement hints consumed by transformers/accelerate: no module
    # is declared unsplittable, and the KV cache is skipped during automatic
    # placement.
    _skip_keys_device_placement = "past_key_values"
    _no_split_modules = []
|
|
|
|
|
class MultiModalityCausalLM(MultiModalityPreTrainedModel):
    """Multimodal causal LM combining vision encoders/decoders with a Llama LM.

    Sub-modules:
      * an image-understanding encoder plus a linear aligner into the
        2048-wide embedding space used by ``prepare_inputs_embeds``;
      * an image-generation encoder (with a 768 -> 2048 aligner) and decoder
        (with an RMSNorm + 2048 -> 768 aligner);
      * a ``LlamaForCausalLM`` language model.
    """

    def __init__(self, config: MultiModalityConfig):
        super().__init__(config)

        # Image-understanding path. The aligner's 2048 output width must match
        # the language model's embedding width, because prepare_inputs_embeds
        # writes these features directly into ``inputs_embeds``.
        self.vision_und_enc_model = self._build_vision_module(
            config.vision_und_enc_config
        )
        self.vision_und_enc_aligner = nn.Linear(1024, 2048, bias=True)

        # Learned token prepended to every image's feature sequence.
        self.beg_of_und_embed = nn.Parameter(torch.zeros(1, 2048))

        # Image-generation encoder + aligner. Not used by the code in this
        # class; presumably consumed by the generation pipeline — TODO confirm.
        self.vision_gen_enc_model = self._build_vision_module(
            config.vision_gen_enc_config
        )
        self.vision_gen_enc_aligner = nn.Linear(768, 2048, bias=True)

        self.vision_gen_dec_model = self._build_vision_module(
            config.vision_gen_dec_config
        )

        language_config = config.language_config
        self.language_model = LlamaForCausalLM(language_config)

        # Normalize + project 2048-wide hidden states down to 768 for the
        # generation decoder; presumably applied outside this class — TODO confirm.
        self.vision_gen_dec_aligner_norm = LlamaRMSNorm(
            2048, eps=language_config.rms_norm_eps
        )
        self.vision_gen_dec_aligner = nn.Linear(2048, 768, bias=True)

    @staticmethod
    def _build_vision_module(sub_config):
        """Instantiate the class named by ``sub_config.cls`` with ``sub_config.params``."""
        module_cls = model_name_to_cls(sub_config.cls)
        return module_cls(**sub_config.params)

    def prepare_inputs_embeds(
        self,
        input_ids: torch.LongTensor,
        pixel_values: torch.FloatTensor,
        images_seq_mask: torch.BoolTensor,
        images_emb_mask: torch.BoolTensor,
        **kwargs,
    ):
        """Build language-model input embeddings with image features spliced in.

        Token positions are embedded with the language model's input embedding
        table. Positions flagged in ``images_seq_mask`` are then overwritten, in
        order, by the image-token embeddings selected by ``images_emb_mask``.

        Each image contributes one learned ``beg_of_und_embed`` token followed by
        the aligned understanding-encoder tokens, so the last dimension of
        ``images_emb_mask`` must include that extra leading token.

        Args:
            input_ids (torch.LongTensor): [b, T] token ids. Negative ids
                (presumably placeholders) are embedded as id 0; the caller's
                tensor is NOT modified.
            pixel_values (torch.FloatTensor): [b, n_images, 3, h, w]
            images_seq_mask (torch.BoolTensor): [b, T] mask of image positions
                in the token sequence.
            images_emb_mask (torch.BoolTensor): [b, n_images, n_image_tokens]
                mask selecting which image-token embeddings to use.
            **kwargs: Accepted and ignored.

        The number of True entries must agree, i.e.
        ``images_seq_mask.sum() == images_emb_mask.sum()``; the masked
        assignment below relies on it.

        Returns:
            input_embeds (torch.Tensor): [b, T, D]
        """
        bs, n = pixel_values.shape[0:2]
        images = rearrange(pixel_values, "b n c h w -> (b n) c h w")

        # Encode all images as one batch and project to the LM width.
        images_embeds = self.vision_und_enc_aligner(self.vision_und_enc_model(images))

        # Prepend the begin-of-understanding token to every image. It is
        # detached, so this path does not send gradients into the parameter.
        beg_of_und_embed = self.beg_of_und_embed[0].detach().clone()
        beg_tokens = beg_of_und_embed.view(1, 1, -1).repeat(images_embeds.shape[0], 1, 1)
        images_embeds = torch.cat([beg_tokens, images_embeds], dim=1)

        # Flatten per-image token sequences into one sequence per batch item.
        images_embeds = rearrange(images_embeds, "(b n) t d -> b (n t) d", b=bs, n=n)
        images_emb_mask = rearrange(images_emb_mask, "b n t -> b (n t)")

        # Clamp out-of-range ids on a copy; a masked in-place write would
        # silently alter the caller's ``input_ids``.
        input_ids = input_ids.clamp(min=0)
        inputs_embeds = self.language_model.get_input_embeddings()(input_ids)

        # Masks must be boolean: integer tensors would trigger integer-gather
        # indexing instead of masking. ``.bool()`` is a no-op for bool tensors.
        inputs_embeds[images_seq_mask.bool()] = images_embeds[images_emb_mask.bool()]

        return inputs_embeds
|
|
|
|
|
# Expose the custom configs and model through the Hugging Face ``Auto*``
# factories, keyed by each config's ``model_type`` string.
for _model_type, _config_cls in (
    ("vision_und_enc", VisionUnderstandEncoderConfig),
    ("vision_gen_enc", VisionGenerationEncoderConfig),
    ("vision_gen_dec", VisionGenerationDecoderConfig),
    ("multi_modality", MultiModalityConfig),
):
    AutoConfig.register(_model_type, _config_cls)
# Do not leak the loop variables into the module namespace.
del _model_type, _config_cls

AutoModelForCausalLM.register(MultiModalityConfig, MultiModalityCausalLM)
|
|