File size: 15,879 Bytes
6073e55
23fdbc0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
655ce8d
 
 
 
23fdbc0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
655ce8d
23fdbc0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
# Copyright (c) 2025 Ye Liu. Licensed under the BSD-3-Clause License.

import warnings

import nncore
import torch
import torch.nn as nn
import torch.nn.functional as F
from nncore.nn import ModuleList, PositionalEncoding, Sequential, TransformerEncoderLayer, xavier_init_
from nncore.ops import temporal_iou
from transformers import AutoConfig, AutoModel, Qwen2VLConfig, Qwen2VLForConditionalGeneration, Qwen2VLModel
from transformers.activations import ACT2CLS, ACT2FN
from transformers.models.auto.modeling_auto import MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES
from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VisionTransformerPretrainedModel

from .blocks import ConvHead, ConvPyramid, LearnableEmbedding, Scale
from .generator import PointGenerator
from .loss import BundleLoss


def cache_state_hook(module, args):
    """Forward pre-hook that remembers the first positional input of ``module``.

    The tensor is stored on ``module.state``. The hook returns None, so the module's inputs are
    left unchanged.
    """
    first_positional = args[0]
    module.state = first_positional


class AgentQwen2VLConfig(Qwen2VLConfig):
    """Qwen2-VL configuration under a dedicated ``model_type``.

    The distinct ``model_type`` lets the agent model be registered with the Auto* factories at the
    bottom of this module.
    """

    model_type = 'agent_qwen2_vl'


class AgentQwen2VisionTransformerPretrainedModel(Qwen2VisionTransformerPretrainedModel):
    """Qwen2-VL vision tower whose ``forward`` supports gradient checkpointing.

    Back-port of https://github.com/huggingface/transformers/pull/34724: when ``gradient_checkpointing``
    is set and the module is in training mode, every vision block is executed through
    ``self._gradient_checkpointing_func``.
    """

    def __init__(self, config):
        super().__init__(config)
        # disabled by default; enabling it is left to the caller / HF checkpointing helpers
        self.gradient_checkpointing = False

    def forward(self, hidden_states, grid_thw):
        """Embed patches, run every vision block and return the output of ``self.merger``.

        Args:
            hidden_states: input fed to ``self.patch_embed``.
            grid_thw: assumed (num_items, 3) rows of (t, h, w) patch-grid sizes — TODO confirm.
        """
        tokens = self.patch_embed(hidden_states)
        rotary_pos_emb = self.rot_pos_emb(grid_thw)

        # one attention segment of h * w tokens per temporal slice; cumulative ends prefixed with 0
        tokens_per_slice = grid_thw[:, 1] * grid_thw[:, 2]
        slice_lengths = torch.repeat_interleave(tokens_per_slice, grid_thw[:, 0])
        cu_seqlens = slice_lengths.cumsum(dim=0, dtype=torch.int32)
        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)

        for block in self.blocks:
            if not (self.gradient_checkpointing and self.training):
                tokens = block(tokens, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb)
                continue
            tokens = self._gradient_checkpointing_func(block.__call__, tokens, cu_seqlens, rotary_pos_emb)

        return self.merger(tokens)


class AgentQwen2VLModel(Qwen2VLModel):
    """Qwen2-VL decoder that records the input of its final norm layer.

    A forward pre-hook (``cache_state_hook``) on ``self.norm`` stores the tensor passed to the norm
    on ``self.norm.state`` after every forward pass.
    """

    config_class = AgentQwen2VLConfig

    def __init__(self, config):
        super().__init__(config)
        self.norm.register_forward_pre_hook(cache_state_hook)

    def forward(self, input_ids=None, inputs_embeds=None, **kwargs):
        """Run the decoder on precomputed embeddings.

        ``input_ids`` must be None and ``inputs_embeds`` must be provided. In training mode,
        embeddings that do not require grad are switched to require it, so gradient tracking
        survives even when ``embed_tokens`` has been frozen.
        """
        assert input_ids is None and inputs_embeds is not None
        needs_grad_flag = self.training and not inputs_embeds.requires_grad
        if needs_grad_flag:
            inputs_embeds.requires_grad = True
        return super().forward(input_ids=input_ids, inputs_embeds=inputs_embeds, **kwargs)


class AgentQwen2VLForConditionalGeneration(Qwen2VLForConditionalGeneration):
    """Qwen2-VL for conditional generation with an optional temporal-grounding head.

    When ``config.role`` is ``'all_in_one'`` or ``'grounder'``, extra modules are built. They combine
    two inputs into temporal boundaries plus saliency scores:

    * the hidden state preceding a ``config.reg_token_id`` token;
    * the pooled hidden states of the first video.

    ``forward`` runs in one of three modes:

    * ``'training'``: if ``timestamps`` is given, a grounding loss from ``BundleLoss`` is added to the
      LM loss.
    * ``'caching'`` (eval, empty ``past_key_values``): the input of the final decoder norm is kept in
      ``self.cache_norm_state``, and the buffers ``self.reg`` / ``self.sal`` are reset.
    * ``'generating'`` (eval, non-empty cache): whenever the argmax next token is ``reg_token_id``:
      - boundaries are decoded and suppressed with NMS;
      - the top-100 are appended to ``self.reg``;
      - the saliency scores are appended to ``self.sal``.
    """

    config_class = AgentQwen2VLConfig

    def __init__(self, config):
        super().__init__(config)
        # swap in the patched vision tower (gradient checkpointing) and decoder (caches the norm input)
        self.visual = AgentQwen2VisionTransformerPretrainedModel._from_config(config.vision_config)
        self.model = AgentQwen2VLModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.rope_deltas = None

        if self.config.role in ('all_in_one', 'grounder'):
            hidden_size, hidden_act = self.config.hidden_size, self.config.hidden_act

            # channel width of the grounding head
            self.dims = 256

            # project LLM hidden states of video tokens / <reg> tokens into the head's width
            self.vis_proj = Sequential(nn.LayerNorm(hidden_size), nn.Linear(hidden_size, self.dims))
            self.reg_proj = Sequential(nn.LayerNorm(hidden_size), nn.Linear(hidden_size, self.dims))
            self.vis_norm = nn.LayerNorm(self.dims)
            self.vis_fuse = ModuleList(
                TransformerEncoderLayer(self.dims, act_cfg=ACT2FN[hidden_act]),
                TransformerEncoderLayer(self.dims, act_cfg=ACT2FN[hidden_act]),
                TransformerEncoderLayer(self.dims, act_cfg=ACT2FN[hidden_act]))

            self.vis_pos = PositionalEncoding(self.dims, normalize=True, learnable=False)
            self.vis_emb = LearnableEmbedding(self.dims)
            self.reg_emb = LearnableEmbedding(self.dims)

            self.strides = (1, 2, 4, 8)
            # during training, videos shorter than the largest stride are zero-padded to this length
            self.vis_pad_length = self.strides[-1]

            self.pyramid = ConvPyramid(self.dims, self.strides, act_cls=ACT2CLS[hidden_act])
            self.class_head = ConvHead(self.dims, 1, act_cls=ACT2CLS[hidden_act])
            self.coord_head = ConvHead(self.dims, 2, act_cls=ACT2CLS[hidden_act])

            self.generator = PointGenerator(self.strides, 1024)
            self.coef = Scale(self.strides)
            self.bundle_loss = BundleLoss(
                sample_radius=1.5,
                loss_cls=dict(type='FocalLoss', reduction='none', loss_weight=5.0),
                loss_reg=dict(type='L1Loss', reduction='none', loss_weight=1.0),
                loss_sal=dict(type='SampledNCELoss', direction='row', loss_weight=0.05))

        self.post_init()

    def reset_conv_parameters(self):
        """Re-initialize every Conv1d / ConvTranspose1d in the pyramid and heads (uniform Xavier).

        Heads that were not built (e.g. for roles without grounding) are skipped; each reset module
        is printed.
        """
        for s in ('pyramid', 'class_head', 'coord_head'):
            b = getattr(self, s, None)
            if b is None:
                continue
            for n, m in b.named_modules():
                if isinstance(m, (nn.Conv1d, nn.ConvTranspose1d)):
                    print(f'Reset parameters of {b.__class__.__name__} {n} ({m.__class__.__name__})')
                    xavier_init_(m, distribution='uniform')

    def _video_window(self, video_grid_thw, s, e):
        """Return how many LLM tokens one frame of the first video occupies.

        One frame of ``video_grid_thw[0]`` (t, h, w) contributes ``h * w / spatial_merge_size ** 2``
        tokens. The assertion checks that the token span ``[s, e)`` holds exactly ``t`` such frames.
        """
        # read the merge size from the vision config instead of hard-coding 2 (the Qwen2-VL default)
        merge_size = getattr(getattr(self.config, 'vision_config', None), 'spatial_merge_size', 2)
        window = int(video_grid_thw[0][1] * video_grid_thw[0][2]) // (merge_size**2)
        assert video_grid_thw[0][0] * window == e - s
        return window

    def _build_joint_tokens(self, vis_states, reg_states, window):
        """Pool video hidden states per frame, project both inputs and concatenate them.

        Args:
            vis_states: (1, e - s, hidden) hidden states of the video tokens.
            reg_states: (num_reg_tokens, 1, hidden) hidden states used as queries.
            window: tokens per frame; frames are average-pooled over this many tokens.

        Returns:
            (num_reg_tokens, num_frames + 1, dims) tensor: position-encoded frame tokens followed by
            one query token per row.
        """
        reg_tokens = self.reg_proj(reg_states)
        # reg_tokens: num_reg_tokens * 1 * channel

        vis_tokens = vis_states.transpose(-1, -2)
        # pool in float32, then cast back to the original dtype
        vis_tokens = F.avg_pool1d(vis_tokens.float(), window, stride=window).to(vis_tokens.dtype)
        vis_tokens = vis_tokens.transpose(-1, -2)
        vis_tokens = self.vis_proj(vis_tokens).repeat(reg_tokens.size(0), 1, 1)
        # vis_tokens: num_reg_tokens * num_frames * channel

        vis_tokens = self.vis_emb(vis_tokens)
        reg_tokens = self.reg_emb(reg_tokens)
        pe = self.vis_pos(vis_tokens).to(vis_tokens.dtype)

        return torch.cat((vis_tokens + pe, reg_tokens), dim=1)

    @staticmethod
    def _temporal_nms(bnd, nms_cfg):
        """Greedy NMS over rows ``[start, end, score]``; suppressed rows get their score reduced.

        ``nms_cfg['type']`` selects the scheme:

        * ``'normal'``: zero the scores with IoU >= ``thres``;
        * ``'linear'``: scale them by ``1 - IoU``;
        * ``'gaussian'``: scale them by ``exp(-IoU^2 / sigma)``.

        Returns the reordered tensor.
        """
        assert nms_cfg['type'] in ('normal', 'linear', 'gaussian')

        for i in range(bnd.size(0)):
            # bring the highest remaining score to position i
            max_idx = bnd[i:, -1].argmax(dim=0)
            bnd = nncore.swap_element(bnd, i, max_idx + i)
            iou = temporal_iou(bnd[i, None, :-1], bnd[i + 1:, :-1])[0]

            if nms_cfg['type'] == 'normal':
                bnd[i + 1:, -1][iou >= nms_cfg['thres']] = 0
            elif nms_cfg['type'] == 'linear':
                bnd[i + 1:, -1] *= 1 - iou
            else:
                bnd[i + 1:, -1] *= (-iou.pow(2) / nms_cfg['sigma']).exp()

        return bnd

    def forward(self,
                input_ids=None,
                attention_mask=None,
                position_ids=None,
                past_key_values=None,
                inputs_embeds=None,
                labels=None,
                use_cache=None,
                output_attentions=None,
                output_hidden_states=None,
                return_dict=None,
                pixel_values=None,
                pixel_values_videos=None,
                image_grid_thw=None,
                video_grid_thw=None,
                rope_deltas=None,
                timestamps=None,
                saliency=None,
                pos_clip=None):
        """Standard Qwen2-VL forward plus temporal grounding (see the class docstring for the modes).

        Extra args (training only):
            timestamps: per-sample ground-truth boundaries. They are passed to ``BundleLoss`` as
                ``boundary``, and their length is used to average the grounding loss.
            saliency, pos_clip: per-sample targets forwarded to ``BundleLoss``.

        Note: ``use_cache`` and ``return_dict`` are ignored. The KV cache is enabled exactly when not
        training, and a ``ModelOutput`` is always returned.
        """
        mode = 'training' if self.training else 'caching' if (
            past_key_values is None or len(past_key_values) == 0) else 'generating'

        # https://github.com/huggingface/transformers/pull/33487
        if position_ids is None and input_ids is not None:
            position_ids, _ = self.get_rope_index(input_ids, image_grid_thw, video_grid_thw, attention_mask)

        if mode in ('training', 'caching'):
            # per sample: list of [start, end) token spans strictly between vision start/end markers
            vision_s_inds = torch.nonzero(input_ids == self.config.vision_start_token_id).tolist()
            vision_e_inds = torch.nonzero(input_ids == self.config.vision_end_token_id).tolist()
            assert len(vision_s_inds) == len(vision_e_inds)

            self.cache_vision_inds = [[] for _ in range(input_ids.size(0))]
            for s_ind, e_ind in zip(vision_s_inds, vision_e_inds):
                assert s_ind[0] == e_ind[0]
                self.cache_vision_inds[s_ind[0]].append([s_ind[1] + 1, e_ind[1]])

        outputs = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=not self.training,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            pixel_values=pixel_values,
            pixel_values_videos=pixel_values_videos,
            image_grid_thw=image_grid_thw,
            video_grid_thw=video_grid_thw,
            rope_deltas=rope_deltas)

        if mode == 'caching':
            # keep the prompt's hidden states: later generation steps only see the newest token
            self.cache_norm_state = self.model.norm.state
            self.reg = []
            self.sal = []

        if mode == 'training' and timestamps is not None:
            loss_regs, avg_factors = [], []
            # hidden states are taken at positions whose *next* token is <reg> (labels shifted by one)
            shift_labels = labels[..., 1:].contiguous()
            for batch_idx, (vision_inds, ts) in enumerate(zip(self.cache_vision_inds, timestamps)):
                # only consider the first video
                s, e = vision_inds[0]
                window = self._video_window(video_grid_thw, s, e)

                inds = torch.where(shift_labels[batch_idx] == self.config.reg_token_id)[0]
                joint_tokens = self._build_joint_tokens(self.model.norm.state[batch_idx, None, s:e],
                                                        self.model.norm.state[batch_idx, inds, None], window)

                # the outputs of every fusion layer (not only the last) are concatenated along dim 0,
                # so the loss below supervises all of them
                collected = [joint_tokens]
                for blk in self.vis_fuse:
                    collected.append(blk(collected[-1]))
                joint_tokens = self.vis_norm(torch.cat(collected[1:]))

                video_emb = joint_tokens[:, :-1]
                # video_emb: num_reg_tokens * num_frames * channel

                query_emb = joint_tokens[:, -1:]
                # query_emb: num_reg_tokens * 1 * channel

                b, t, c = video_emb.size()
                video_msk = video_emb.new_ones(b, t)

                if t < self.vis_pad_length:
                    emb_pad = video_emb.new_zeros(b, self.vis_pad_length - t, c)
                    msk_pad = video_msk.new_zeros(b, self.vis_pad_length - t)
                    pymid_emb = torch.cat((video_emb, emb_pad), dim=1)
                    pymid_msk = torch.cat((video_msk, msk_pad), dim=1)
                else:
                    pymid_emb, pymid_msk = video_emb, video_msk

                pymid, pymid_msk = self.pyramid(pymid_emb, pymid_msk, return_mask=True)
                if not len(pymid) == len(pymid_msk) == len(self.strides):
                    warnings.warn(f'pyramid size mismatch: {len(pymid)} {len(pymid_msk)} {len(self.strides)}')

                point = self.generator(pymid)

                # raw logits here; the generation path applies sigmoid instead
                out_class = torch.cat([self.class_head(e) for e in pymid], dim=1)
                out_coord = torch.cat([self.coef(self.coord_head(e).exp(), i) for i, e in enumerate(pymid)], dim=1)

                data = dict(
                    point=point,
                    video_emb=video_emb,
                    query_emb=query_emb,
                    video_msk=video_msk,
                    pymid_msk=pymid_msk,
                    out_class=out_class,
                    out_coord=out_coord,
                    boundary=point.new_tensor(ts),
                    saliency=saliency[batch_idx].unsqueeze(0),
                    pos_clip=pos_clip[batch_idx].unsqueeze(0))

                losses = self.bundle_loss(data, dict())

                loss_regs.append(sum(losses.values()))
                avg_factors.append(len(ts))

            assert len(loss_regs) in (1, 2) and len(loss_regs) == len(avg_factors)

            # with two samples, only the one with the smaller grounding loss contributes
            if len(loss_regs) == 2 and loss_regs[0] > loss_regs[1]:
                loss_reg, avg_factor = loss_regs[1], avg_factors[1]
            else:
                loss_reg, avg_factor = loss_regs[0], avg_factors[0]

            if avg_factor > 0:
                outputs.loss = outputs.loss + loss_reg / avg_factor
        elif mode == 'generating':
            logits = outputs.logits[0, -1]
            if logits.argmax() == self.config.reg_token_id:
                assert self.model.norm.state.size() == (1, 1, self.config.hidden_size)

                # only consider the first video
                s, e = self.cache_vision_inds[0][0]
                window = self._video_window(video_grid_thw, s, e)

                joint_tokens = self._build_joint_tokens(self.cache_norm_state[:, s:e], self.model.norm.state,
                                                        window)
                for blk in self.vis_fuse:
                    joint_tokens = blk(joint_tokens)
                joint_tokens = self.vis_norm(joint_tokens)

                video_emb = joint_tokens[:, :-1]
                # video_emb: num_reg_tokens * num_frames * channel

                b, t, _ = video_emb.size()
                video_msk = video_emb.new_ones(b, t)

                pymid = self.pyramid(video_emb, video_msk)
                point = self.generator(pymid)

                out_class = torch.cat([self.class_head(e).sigmoid() for e in pymid], dim=1)
                out_coord = torch.cat([self.coef(self.coord_head(e).exp(), i) for i, e in enumerate(pymid)], dim=1)

                sal = out_class[0]
                bnd = out_coord[0]

                # decode the two predicted distances into absolute [start, end] and normalize by the
                # frame count; presumably point[:, 0] is the anchor position and point[:, 3] its
                # stride (see PointGenerator) — TODO confirm
                bnd[:, 0] *= -1
                bnd *= point[:, 3, None].repeat(1, 2)
                bnd += point[:, 0, None].repeat(1, 2)
                bnd /= t
                bnd = torch.cat((bnd, sal), dim=-1)

                _, inds = bnd[:, -1].sort(descending=True)
                bnd = bnd[inds]

                # hard coding nms config here ('sigma' is only read by the gaussian variant)
                nms_cfg = dict(type='normal', thres=0.75, sigma=0.5)
                bnd = self._temporal_nms(bnd, nms_cfg)

                # save top-100 predictions
                self.reg.append(bnd[:100])

                # save all saliency scores
                self.sal.append(sal)

        return outputs


# Make the agent model discoverable by the transformers Auto* machinery:
# - add its model_type to the vision-to-sequence mapping;
# - register the config/model pair with AutoConfig and AutoModel.
MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES.update({AgentQwen2VLConfig.model_type: 'AgentQwen2VLForConditionalGeneration'})

AutoConfig.register(AgentQwen2VLConfig.model_type, AgentQwen2VLConfig)
AutoModel.register(AgentQwen2VLConfig, AgentQwen2VLForConditionalGeneration)