# Copyright (c) 2025 Ye Liu. Licensed under the BSD-3-Clause License.

import warnings

import nncore
import torch
import torch.nn as nn
import torch.nn.functional as F
from nncore.nn import ModuleList, PositionalEncoding, Sequential, TransformerEncoderLayer, xavier_init_
from nncore.ops import temporal_iou
from transformers import AutoConfig, AutoModel, Qwen2VLConfig, Qwen2VLForConditionalGeneration, Qwen2VLModel
from transformers.activations import ACT2CLS, ACT2FN
from transformers.models.auto.modeling_auto import MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES
from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VisionTransformerPretrainedModel

from .blocks import ConvHead, ConvPyramid, LearnableEmbedding, Scale
from .generator import PointGenerator
from .loss import BundleLoss


def cache_state_hook(module, args):
    # forward pre-hook: cache the input hidden states of the hooked module (here: the final norm)
    module.state = args[0]


class AgentQwen2VLConfig(Qwen2VLConfig):
    model_type = 'agent_qwen2_vl'


class AgentQwen2VisionTransformerPretrainedModel(Qwen2VisionTransformerPretrainedModel):

    def __init__(self, config):
        super().__init__(config)
        self.gradient_checkpointing = False

    # add support for gradient checkpointing
    # https://github.com/huggingface/transformers/pull/34724
    def forward(self, hidden_states, grid_thw):
        hidden_states = self.patch_embed(hidden_states)
        rotary_pos_emb = self.rot_pos_emb(grid_thw)

        cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
            dim=0, dtype=torch.int32)
        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)

        for blk in self.blocks:
            if self.gradient_checkpointing and self.training:
                hidden_states = self._gradient_checkpointing_func(blk.__call__, hidden_states, cu_seqlens,
                                                                  rotary_pos_emb)
            else:
                hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb)

        return self.merger(hidden_states)


class AgentQwen2VLModel(Qwen2VLModel):

    config_class = AgentQwen2VLConfig

    def __init__(self, config):
        super().__init__(config)
        self.norm.register_forward_pre_hook(cache_state_hook)

    def forward(self, input_ids=None, inputs_embeds=None, **kwargs):
        # ensure gradient tracking (in case embed_tokens has been frozen)
        assert input_ids is None and inputs_embeds is not None
        if self.training and not inputs_embeds.requires_grad:
            inputs_embeds.requires_grad = True
        return super().forward(input_ids=input_ids, inputs_embeds=inputs_embeds, **kwargs)


class AgentQwen2VLForConditionalGeneration(Qwen2VLForConditionalGeneration):

    config_class = AgentQwen2VLConfig

    def __init__(self, config):
        super().__init__(config)

        self.visual = AgentQwen2VisionTransformerPretrainedModel._from_config(config.vision_config)
        self.model = AgentQwen2VLModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.rope_deltas = None

        if self.config.role in ('all_in_one', 'grounder'):
            hidden_size, hidden_act = self.config.hidden_size, self.config.hidden_act

            self.dims = 256

            self.vis_proj = Sequential(nn.LayerNorm(hidden_size), nn.Linear(hidden_size, self.dims))
            self.reg_proj = Sequential(nn.LayerNorm(hidden_size), nn.Linear(hidden_size, self.dims))

            self.vis_norm = nn.LayerNorm(self.dims)
            self.vis_fuse = ModuleList(
                TransformerEncoderLayer(self.dims, act_cfg=ACT2FN[hidden_act]),
                TransformerEncoderLayer(self.dims, act_cfg=ACT2FN[hidden_act]),
                TransformerEncoderLayer(self.dims, act_cfg=ACT2FN[hidden_act]))

            self.vis_pos = PositionalEncoding(self.dims, normalize=True, learnable=False)
            self.vis_emb = LearnableEmbedding(self.dims)
            self.reg_emb = LearnableEmbedding(self.dims)

            self.strides = (1, 2, 4, 8)
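            # strides of the temporal feature pyramid; videos shorter than the coarsest
            # stride are padded in the forward pass so that every pyramid level is non-empty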
            self.vis_pad_length = self.strides[-1]

            self.pyramid = ConvPyramid(self.dims, self.strides, act_cls=ACT2CLS[hidden_act])
            self.class_head = ConvHead(self.dims, 1, act_cls=ACT2CLS[hidden_act])
            self.coord_head = ConvHead(self.dims, 2, act_cls=ACT2CLS[hidden_act])

            self.generator = PointGenerator(self.strides, 1024)
            self.coef = Scale(self.strides)
            self.bundle_loss = BundleLoss(
                sample_radius=1.5,
                loss_cls=dict(type='FocalLoss', reduction='none', loss_weight=5.0),
                loss_reg=dict(type='L1Loss', reduction='none', loss_weight=1.0),
                loss_sal=dict(type='SampledNCELoss', direction='row', loss_weight=0.05))

        self.post_init()

    def reset_conv_parameters(self):
        for s in ('pyramid', 'class_head', 'coord_head'):
            b = getattr(self, s, None)
            if b is None:
                continue
            for n, m in b.named_modules():
                if isinstance(m, (nn.Conv1d, nn.ConvTranspose1d)):
                    print(f'Reset parameters of {b.__class__.__name__} {n} ({m.__class__.__name__})')
                    xavier_init_(m, distribution='uniform')

    def forward(self,
                input_ids=None,
                attention_mask=None,
                position_ids=None,
                past_key_values=None,
                inputs_embeds=None,
                labels=None,
                use_cache=None,
                output_attentions=None,
                output_hidden_states=None,
                return_dict=None,
                pixel_values=None,
                pixel_values_videos=None,
                image_grid_thw=None,
                video_grid_thw=None,
                rope_deltas=None,
                timestamps=None,
                saliency=None,
                pos_clip=None):
        mode = 'training' if self.training else 'caching' if (
            past_key_values is None or len(past_key_values) == 0) else 'generating'

        # https://github.com/huggingface/transformers/pull/33487
        if position_ids is None and input_ids is not None:
            position_ids, _ = self.get_rope_index(input_ids, image_grid_thw, video_grid_thw, attention_mask)

        if mode in ('training', 'caching'):
            vision_s_inds = torch.nonzero(input_ids == self.config.vision_start_token_id).tolist()
            vision_e_inds = torch.nonzero(input_ids == self.config.vision_end_token_id).tolist()
            assert len(vision_s_inds) == len(vision_e_inds)

            self.cache_vision_inds = [[] for _ in range(input_ids.size(0))]
            for i in range(len(vision_s_inds)):
                assert vision_s_inds[i][0] == vision_e_inds[i][0]
                self.cache_vision_inds[vision_s_inds[i][0]].append([vision_s_inds[i][1] + 1, vision_e_inds[i][1]])

        outputs = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=not self.training,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            pixel_values=pixel_values,
            pixel_values_videos=pixel_values_videos,
            image_grid_thw=image_grid_thw,
            video_grid_thw=video_grid_thw,
            rope_deltas=rope_deltas)

        if mode == 'caching':
            self.cache_norm_state = self.model.norm.state
            self.reg = []
            self.sal = []

        if mode == 'training' and timestamps is not None:
            loss_regs, avg_factors = [], []

            shift_labels = labels[..., 1:].contiguous()

            for batch_idx, (vision_inds, ts) in enumerate(zip(self.cache_vision_inds, timestamps)):
                # only consider the first video
                s, e = vision_inds[0]

                # spatial merge size set to 2
                window = int(video_grid_thw[0][1] * video_grid_thw[0][2] / 4)
                assert video_grid_thw[0][0] * window == e - s

                inds = torch.where(shift_labels[batch_idx] == self.config.reg_token_id)[0]

                reg_tokens = self.reg_proj(self.model.norm.state[batch_idx, inds, None])
                # reg_tokens: num_reg_tokens * 1 * channel

                vis_tokens = self.model.norm.state[batch_idx, None, s:e]
                vis_tokens = vis_tokens.transpose(-1, -2)
                vis_tokens = F.avg_pool1d(vis_tokens.float(), window, stride=window).to(vis_tokens.dtype)
                vis_tokens = vis_tokens.transpose(-1, -2)
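                # the pooling above collapses the H * W / 4 patch tokens of each frame (after
                # the 2x2 spatial merge) into a single per-frame embedding, which is then
                # projected to the grounder dimension and broadcast across the reg tokens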
                vis_tokens = self.vis_proj(vis_tokens).repeat(reg_tokens.size(0), 1, 1)
                # vis_tokens: num_reg_tokens * num_frames * channel

                vis_tokens = self.vis_emb(vis_tokens)
                reg_tokens = self.reg_emb(reg_tokens)

                pe = self.vis_pos(vis_tokens).to(vis_tokens.dtype)
                joint_tokens = torch.cat((vis_tokens + pe, reg_tokens), dim=1)

                collected = [joint_tokens]
                for blk in self.vis_fuse:
                    collected.append(blk(collected[-1]))
                collected = collected[1:]

                joint_tokens = torch.cat(collected)
                joint_tokens = self.vis_norm(joint_tokens)

                video_emb = joint_tokens[:, :-1]
                # video_emb: num_reg_tokens * num_frames * channel

                query_emb = joint_tokens[:, -1:]
                # query_emb: num_reg_tokens * 1 * channel

                b, t, c = video_emb.size()
                video_msk = video_emb.new_ones(b, t)

                if t < self.vis_pad_length:
                    emb_pad = video_emb.new_zeros(b, self.vis_pad_length - t, c)
                    msk_pad = video_msk.new_zeros(b, self.vis_pad_length - t)
                    pymid_emb = torch.cat((video_emb, emb_pad), dim=1)
                    pymid_msk = torch.cat((video_msk, msk_pad), dim=1)
                else:
                    pymid_emb, pymid_msk = video_emb, video_msk

                pymid, pymid_msk = self.pyramid(pymid_emb, pymid_msk, return_mask=True)
                if not len(pymid) == len(pymid_msk) == len(self.strides):
                    warnings.warn(f'pyramid size mismatch: {len(pymid)} {len(pymid_msk)} {len(self.strides)}')

                point = self.generator(pymid)

                out_class = [self.class_head(e) for e in pymid]
                out_class = torch.cat(out_class, dim=1)

                out_coord = [self.coef(self.coord_head(e).exp(), i) for i, e in enumerate(pymid)]
                out_coord = torch.cat(out_coord, dim=1)

                data = dict(
                    point=point,
                    video_emb=video_emb,
                    query_emb=query_emb,
                    video_msk=video_msk,
                    pymid_msk=pymid_msk,
                    out_class=out_class,
                    out_coord=out_coord,
                    boundary=point.new_tensor(ts),
                    saliency=saliency[batch_idx].unsqueeze(0),
                    pos_clip=pos_clip[batch_idx].unsqueeze(0))

                losses = self.bundle_loss(data, dict())
                # print({k: v.item() for k, v in losses.items()})

                loss_regs.append(sum(v for v in losses.values()))
                avg_factors.append(len(ts))

            assert len(loss_regs) in (1, 2) and len(loss_regs) == len(avg_factors)

            if len(loss_regs) == 2 and loss_regs[0] > loss_regs[1]:
                loss_reg, avg_factor = loss_regs[1], avg_factors[1]
            else:
                loss_reg, avg_factor = loss_regs[0], avg_factors[0]

            if avg_factor > 0:
                outputs.loss = outputs.loss + loss_reg / avg_factor
        elif mode == 'generating':
            logits = outputs.logits[0, -1]

            if logits.argmax() == self.config.reg_token_id:
                assert self.model.norm.state.size() == (1, 1, self.config.hidden_size)

                # only consider the first video
                s, e = self.cache_vision_inds[0][0]

                # spatial merge size set to 2
                window = int(video_grid_thw[0][1] * video_grid_thw[0][2] / 4)
                assert video_grid_thw[0][0] * window == e - s

                reg_tokens = self.reg_proj(self.model.norm.state)
                # reg_tokens: num_reg_tokens * 1 * channel

                vis_tokens = self.cache_norm_state[:, s:e]
                vis_tokens = vis_tokens.transpose(-1, -2)
                vis_tokens = F.avg_pool1d(vis_tokens.float(), window, stride=window).to(vis_tokens.dtype)
                vis_tokens = vis_tokens.transpose(-1, -2)
                vis_tokens = self.vis_proj(vis_tokens).repeat(reg_tokens.size(0), 1, 1)
                # vis_tokens: num_reg_tokens * num_frames * channel

                vis_tokens = self.vis_emb(vis_tokens)
                reg_tokens = self.reg_emb(reg_tokens)

                pe = self.vis_pos(vis_tokens).to(vis_tokens.dtype)
                joint_tokens = torch.cat((vis_tokens + pe, reg_tokens), dim=1)

                for blk in self.vis_fuse:
                    joint_tokens = blk(joint_tokens)
                joint_tokens = self.vis_norm(joint_tokens)

                video_emb = joint_tokens[:, :-1]
                # video_emb: num_reg_tokens * num_frames * channel

                query_emb = joint_tokens[:, -1:]
                # query_emb: num_reg_tokens * 1 * channel

                b, t, _ = video_emb.size()
                video_msk = video_emb.new_ones(b, t)
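                # build the temporal pyramid over the fused per-frame features; at inference
                # the class logits go through a sigmoid to serve as confidence scores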
                pymid = self.pyramid(video_emb, video_msk)
                point = self.generator(pymid)

                out_class = [self.class_head(e).sigmoid() for e in pymid]
                out_class = torch.cat(out_class, dim=1)

                out_coord = [self.coef(self.coord_head(e).exp(), i) for i, e in enumerate(pymid)]
                out_coord = torch.cat(out_coord, dim=1)

                sal = out_class[0]
                bnd = out_coord[0]

                # decode the predicted offsets around each anchor point into normalized (start, end) boundaries
                bnd[:, 0] *= -1
                bnd *= point[:, 3, None].repeat(1, 2)
                bnd += point[:, 0, None].repeat(1, 2)
                bnd /= t
                bnd = torch.cat((bnd, sal), dim=-1)

                _, inds = bnd[:, -1].sort(descending=True)
                bnd = bnd[inds]

                # hard-coded NMS config ('sigma' would only be needed for the 'gaussian' type)
                nms_cfg = dict(type='normal', thres=0.75)
                assert nms_cfg['type'] in ('normal', 'linear', 'gaussian')

                for i in range(bnd.size(0)):
                    max_idx = bnd[i:, -1].argmax(dim=0)
                    bnd = nncore.swap_element(bnd, i, max_idx + i)
                    iou = temporal_iou(bnd[i, None, :-1], bnd[i + 1:, :-1])[0]

                    if nms_cfg['type'] == 'normal':
                        bnd[i + 1:, -1][iou >= nms_cfg['thres']] = 0
                    elif nms_cfg['type'] == 'linear':
                        bnd[i + 1:, -1] *= 1 - iou
                    else:
                        bnd[i + 1:, -1] *= (-iou.pow(2) / nms_cfg['sigma']).exp()

                # save top-100 predictions
                self.reg.append(bnd[:100])

                # save all saliency scores
                self.sal.append(sal)

        return outputs


# register the patched model as a vision-to-sequence model so that the Auto* classes can resolve it
MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES[AgentQwen2VLConfig.model_type] = 'AgentQwen2VLForConditionalGeneration'

AutoConfig.register(AgentQwen2VLConfig.model_type, AgentQwen2VLConfig)
AutoModel.register(AgentQwen2VLConfig, AgentQwen2VLForConditionalGeneration)
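# Illustrative usage sketch (not from the original source): once this module has been imported,
# the registered classes can be resolved through the Auto* API. The checkpoint path below is a
# placeholder, and its config is expected to provide the custom fields used above
# (e.g. 'role' and 'reg_token_id').
#
#   from transformers import AutoConfig, AutoModel
#
#   config = AutoConfig.from_pretrained('path/to/agent_qwen2_vl_checkpoint')  # model_type: 'agent_qwen2_vl'
#   model = AutoModel.from_pretrained('path/to/agent_qwen2_vl_checkpoint', config=config)
#   assert isinstance(model, AgentQwen2VLForConditionalGeneration)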