Questions about Naive Dynamic Resolution and the vision mask

by YaYaGeGe - opened

Hi Qwen team,
Great work, and thanks for open-sourcing it!
According to the implementation in Hugging Face transformers, the image tokens are first reshaped into a token sequence and only then passed through the Conv3d, which seems inconsistent with the paper. I have the same question as this

  • Given the flattened sequence input, the Conv3d layer in PatchEmbed never covers more than 2 frames along the time dimension. The convolution step therefore does not capture information beyond a two-frame window, and with more frames the number of convolution applications simply grows linearly.
  • I think it makes more sense to apply Conv3d to the 4D image tensor first and then reshape/flatten.
  • That said, whether the tokens are flattened first or patchified (convolved) first, the sequence length grows with the number of frames, since seq_len = grid_t * grid_h * grid_w.
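The two orders can be compared directly in a small PyTorch sketch. It assumes a non-overlapping patch embedding (Conv3d stride equal to its kernel size) and uses made-up sizes; it is an illustration, not the transformers code itself. "Order B" mimics the flow described above: cut the clip into patches, flatten them into a sequence, then convolve each patch.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Illustrative sizes (assumptions, not read from any model config)
C, T, H, W = 3, 4, 28, 28   # channels, frames, height, width
pt, ph, pw = 2, 14, 14      # temporal and spatial patch sizes
embed_dim = 16

# Non-overlapping patch embedding: stride equals kernel size
conv = nn.Conv3d(C, embed_dim, kernel_size=(pt, ph, pw), stride=(pt, ph, pw), bias=False)

video = torch.randn(C, T, H, W)  # the 4D clip tensor
grid_t, grid_h, grid_w = T // pt, H // ph, W // pw
seq_len = grid_t * grid_h * grid_w

# Order A: convolve the whole clip, then flatten to a token sequence
out_a = conv(video.unsqueeze(0))                # (1, D, grid_t, grid_h, grid_w)
tokens_a = out_a.flatten(2).transpose(1, 2)[0]  # (seq_len, D), row-major over (t, h, w)

# Order B: cut into patches, flatten to a sequence, then convolve each patch
patches = video.reshape(C, grid_t, pt, grid_h, ph, grid_w, pw)
patches = patches.permute(1, 3, 5, 0, 2, 4, 6)  # (grid_t, grid_h, grid_w, C, pt, ph, pw)
patches = patches.reshape(-1, C, pt, ph, pw)    # (seq_len, C, pt, ph, pw)
tokens_b = conv(patches).reshape(-1, embed_dim) # (seq_len, D)

print(tokens_a.shape, tokens_b.shape, seq_len)
print(torch.allclose(tokens_a, tokens_b, atol=1e-5))
```

Under the non-overlapping assumption, each output token only ever sees its own (pt, ph, pw) patch, so both orders should yield the same grid_t * grid_h * grid_w tokens with matching values; any cross-frame mixing would have to come from later layers.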

Besides, could you elaborate a bit on the rationale for applying the vision mask to the vision attention weights by addition? Is it meant to emphasize attention within one frame pair while suppressing attention to other frames?

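    # boolean block-diagonal mask: True inside each cu_seqlens segment, False elsewhere
    attention_mask = torch.zeros([1, seq_length, seq_length], device=q.device, dtype=torch.bool)
    for i in range(1, len(cu_seqlens)):
        attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True
    # adding a bool tensor to float weights promotes True to 1.0 and False to 0.0
    attn_weights = attn_weights + attention_mask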
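To make the effect of that addition concrete, here is a minimal sketch that reuses the same mask construction on a toy sequence. The cu_seqlens values and the all-zero raw scores are made-up assumptions chosen so the result is easy to read.

```python
import torch

# Hypothetical example: two segments of 2 tokens each
cu_seqlens = torch.tensor([0, 2, 4])
seq_length = int(cu_seqlens[-1])

attention_mask = torch.zeros([1, seq_length, seq_length], dtype=torch.bool)
for i in range(1, len(cu_seqlens)):
    attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True

attn_weights = torch.zeros(1, seq_length, seq_length)  # uniform raw scores for clarity
biased = attn_weights + attention_mask                 # bool promoted: True -> 1.0, False -> 0.0
probs = torch.softmax(biased, dim=-1)

print(biased[0])
print(probs[0])
```

With this boolean mask, in-segment scores are only raised by 1.0, so tokens in the other segment still receive nonzero weight after the softmax.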

It seems the code has since been updated on the main branch: the 'intersection' (cross-frame) regions are now filled with -inf (in practice torch.finfo(q.dtype).min) and the 'self' (within-frame) regions with zero. For grid_t = 2, the mask looks like this:
    | 0    | 0    | -inf | -inf |
    | 0    | 0    | -inf | -inf |
    | -inf | -inf | 0    | 0    |
    | -inf | -inf | 0    | 0    |
This masks out attention between frames, so each patch only attends to the patches within its own frame.
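The layout above can be reproduced with a short sketch. It assumes that each temporal index forms one segment of grid_h * grid_w tokens (so cu_seqlens = [0, hw, 2*hw, ...]); the helper name vision_block_mask is hypothetical.

```python
import torch

def vision_block_mask(grid_t: int, grid_h: int, grid_w: int, dtype=torch.float32) -> torch.Tensor:
    """Hypothetical helper: additive block-diagonal mask.

    0 within a temporal segment, torch.finfo(dtype).min across segments.
    Assumption: cu_seqlens = [0, hw, 2*hw, ..., grid_t*hw] with hw = grid_h * grid_w.
    """
    hw = grid_h * grid_w
    cu_seqlens = torch.arange(0, grid_t + 1) * hw
    seq_length = grid_t * hw
    mask = torch.full([1, seq_length, seq_length], torch.finfo(dtype).min, dtype=dtype)
    for i in range(1, len(cu_seqlens)):
        mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
    return mask

# grid_t = 2 with two patches per temporal step gives the 4x4 layout shown above
mask = vision_block_mask(grid_t=2, grid_h=1, grid_w=2)
print((mask[0] == 0).int())  # 1 marks query/key pairs that are allowed to attend
```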
Does this mask setting really work for video understanding, given that there is no attention across frames? Consecutive frames in a video clip share the same, or semantically continuous, content.

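    attention_mask = torch.full(
        [1, seq_length, seq_length], torch.finfo(q.dtype).min, device=q.device, dtype=q.dtype
    )
    for i in range(1, len(cu_seqlens)):
        attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
    attn_weights = attn_weights + attention_mask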
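One way to see what an additive finfo.min mask of this kind does: full attention over the whole sequence with that mask should match running attention separately inside each segment. The sketch below checks this on random data; the sizes, the single head, and the plain softmax attention are assumptions for illustration.

```python
import torch

torch.manual_seed(0)

# Hypothetical sizes: 2 temporal segments of 3 tokens each, head_dim 4
seg_len, n_seg, head_dim = 3, 2, 4
seq_length = seg_len * n_seg
cu_seqlens = torch.arange(0, n_seg + 1) * seg_len

q = torch.randn(seq_length, head_dim)
k = torch.randn(seq_length, head_dim)
v = torch.randn(seq_length, head_dim)

# Full attention over the whole sequence with the additive block mask
attention_mask = torch.full([seq_length, seq_length], torch.finfo(q.dtype).min, dtype=q.dtype)
for i in range(1, len(cu_seqlens)):
    attention_mask[cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
attn_weights = q @ k.T / head_dim**0.5 + attention_mask
probs = torch.softmax(attn_weights, dim=-1)
out_masked = probs @ v

# Reference: attend within each segment independently
out_per_segment = torch.cat([
    torch.softmax(q[s:e] @ k[s:e].T / head_dim**0.5, dim=-1) @ v[s:e]
    for s, e in zip(cu_seqlens[:-1].tolist(), cu_seqlens[1:].tolist())
])

print(torch.allclose(out_masked, out_per_segment, atol=1e-6))
print(probs[0, seg_len:].max())  # cross-segment weights underflow to exactly 0
```

Unlike the boolean version, the finfo.min entries drive the cross-segment softmax weights to zero, so within this attention call each segment is processed in isolation.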
