# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn.functional as F
from einops import rearrange


def mask_generation(
    crossmap_2d_list, selfmap_2d_list=None,
    target_token=None, mask_scope=None,
    mask_target_h=64, mask_target_w=64,
    mask_mode=["binary"],
):
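    """Build a spatial mask from cross-attention (and optional self-attention) maps.

    The shapes below are inferred from how the tensors are indexed in this
    function, not from upstream documentation:
        crossmap_2d_list: list of [B, heads, tokens, h, w] cross-attention maps.
        selfmap_2d_list: optional list of [B, C, N, N] token-to-token self-attention maps.
        target_token: [B, tokens] 0/1 weights selecting the prompt tokens to mask on.
        mask_scope: kept fraction for top-k selection, or 1 - threshold for thresholding.
        mask_mode: list of normalization/binarization flags, e.g. "binary", "max_norm".

    Returns a [B, 1, mask_target_h, mask_target_w] mask (a per-layer dimension
    is kept at dim 2 when it is not a singleton).
    """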
    # Guard against the default selfmap_2d_list=None before calling len().
    if selfmap_2d_list is not None and len(selfmap_2d_list) > 0:
        # Resize each self-attention map to a common (H*W) x (H*W) token grid
        # so it can later be matmul-ed with the flattened cross-attention map.
        target_hw_selfmap = mask_target_h * mask_target_w
        selfmap_2ds = []
        for selfmap_ in selfmap_2d_list:
            selfmap_ = F.interpolate(selfmap_, size=(target_hw_selfmap, target_hw_selfmap), mode='bilinear')
            selfmap_2ds.append(selfmap_)
        selfmap_2ds = torch.cat(selfmap_2ds, dim=1)
        if "selfmap_min_max_per_channel" in mask_mode:
            # Min-max normalize each self-attention channel independently.
            selfmap_1ds = rearrange(selfmap_2ds, "b c h w -> b c (h w)")
            channel_max_self = torch.max(selfmap_1ds, dim=-1, keepdim=True)[0].unsqueeze(-1)
            channel_min_self = torch.min(selfmap_1ds, dim=-1, keepdim=True)[0].unsqueeze(-1)
            selfmap_2ds = (selfmap_2ds - channel_min_self) / (channel_max_self - channel_min_self + 1e-6)
        elif "selfmap_max_norm" in mask_mode:
            # Scale all channels jointly by the per-sample maximum.
            selfmap_1ds = rearrange(selfmap_2ds, "b c h w -> b c (h w)")
            b = selfmap_1ds.size(0)
            batch_max = torch.max(selfmap_1ds.view(b, -1), dim=-1, keepdim=True)[0].unsqueeze(-1).unsqueeze(-1)
            selfmap_2ds = selfmap_2ds / (batch_max + 1e-10)
        selfmap_2d = selfmap_2ds.mean(dim=1, keepdim=True)  # average over maps -> [b, 1, N, N]
    else:
        selfmap_2d = None
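
    # When selfmap_2d is present, it is matmul-ed with the averaged cross-attention
    # map below, which spreads token relevance across spatially self-similar regions.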
    crossmap_2ds = []
    for crossmap in crossmap_2d_list:
        crossmap = crossmap.mean(dim=1)  # average over the head dimension
        crossmap = crossmap * target_token.unsqueeze(-1).unsqueeze(-1)  # keep only the target tokens
        crossmap = crossmap.sum(dim=1, keepdim=True)  # collapse the token dimension
        crossmap = F.interpolate(crossmap, size=(mask_target_h, mask_target_w), mode='bilinear')
        crossmap_2ds.append(crossmap)
    crossmap_2ds = torch.cat(crossmap_2ds, dim=1)
    crossmap_1ds = rearrange(crossmap_2ds, "b c h w -> b c (h w)")

    if "max_norm" in mask_mode:
        # Scale each sample's scores by its maximum.
        crossmap_1d_avg = torch.mean(crossmap_1ds, dim=1, keepdim=True)  # [b, 1, (h w)]
        if selfmap_2d is not None:
            crossmap_1d_avg = torch.matmul(selfmap_2d, crossmap_1d_avg.unsqueeze(-1)).squeeze(-1)
        b, c, n = crossmap_1ds.shape
        batch_max = torch.max(crossmap_1d_avg.view(b, -1), dim=-1, keepdim=True)[0].unsqueeze(1)  # [b, 1, 1]
        crossmap_1d_avg = crossmap_1d_avg / (batch_max + 1e-6)
    elif "min_max_norm" in mask_mode:
        # Min-max normalize each sample's scores to [0, 1].
        crossmap_1d_avg = torch.mean(crossmap_1ds, dim=1, keepdim=True)  # [b, 1, (h w)]
        if selfmap_2d is not None:
            crossmap_1d_avg = torch.matmul(selfmap_2d, crossmap_1d_avg.unsqueeze(-1)).squeeze(-1)
        b, c, n = crossmap_1ds.shape
        batch_max = torch.max(crossmap_1d_avg.view(b, -1), dim=-1, keepdim=True)[0].unsqueeze(1)  # [b, 1, 1] for broadcasting
        batch_min = torch.min(crossmap_1d_avg.view(b, -1), dim=-1, keepdim=True)[0].unsqueeze(1)  # [b, 1, 1] for broadcasting
        crossmap_1d_avg = (crossmap_1d_avg - batch_min) / (batch_max - batch_min + 1e-6)
    elif "min_max_per_channel" in mask_mode:
        # Min-max normalize each map separately before averaging across them.
        channel_max = torch.max(crossmap_1ds, dim=-1, keepdim=True)[0]
        channel_min = torch.min(crossmap_1ds, dim=-1, keepdim=True)[0]
        crossmap_1ds = (crossmap_1ds - channel_min) / (channel_max - channel_min + 1e-6)
        crossmap_1d_avg = torch.mean(crossmap_1ds, dim=1, keepdim=True)  # [b, 1, (h w)]
        if selfmap_2d is not None:
            crossmap_1d_avg = torch.matmul(selfmap_2d, crossmap_1d_avg.unsqueeze(-1)).squeeze(-1)
        # Renormalize to [0, 1] after the self-attention propagation.
        b, c, n = crossmap_1d_avg.shape
        batch_max = torch.max(crossmap_1d_avg.view(b, -1), dim=-1, keepdim=True)[0].unsqueeze(1)
        batch_min = torch.min(crossmap_1d_avg.view(b, -1), dim=-1, keepdim=True)[0].unsqueeze(1)
        crossmap_1d_avg = (crossmap_1d_avg - batch_min) / (batch_max - batch_min + 1e-6)
    else:
        crossmap_1d_avg = torch.mean(crossmap_1ds, dim=1, keepdim=True)  # [b, 1, (h w)]

    if "threshold" in mask_mode:
        # Zero out scores below 1 - mask_scope; optionally binarize the survivors.
        threshold = 1 - mask_scope
        crossmap_1d_avg[crossmap_1d_avg < threshold] = 0.0
        if "binary" in mask_mode:
            crossmap_1d_avg[crossmap_1d_avg > threshold] = 1.0
    else:
        # Top-k: keep only the highest-scoring mask_scope fraction of positions.
        topk_num = int(crossmap_1d_avg.size(-1) * mask_scope)
        _, sort_order = crossmap_1d_avg.sort(descending=True, dim=-1)
        sort_topk = sort_order[:, :, :topk_num]
        sort_topk_remain = sort_order[:, :, topk_num:]
        crossmap_1d_avg = crossmap_1d_avg.scatter(2, sort_topk_remain, 0.)
        if "binary" in mask_mode:
            crossmap_1d_avg = crossmap_1d_avg.scatter(2, sort_topk, 1.0)

    crossmap_2d_avg = rearrange(crossmap_1d_avg, "b c (h w) -> b c h w", h=mask_target_h, w=mask_target_w)
    # e.g. torch.Size([4, 1, 60, 64, 64]); dim 1 indexes the reference images, dim 2 the layers.
    output = crossmap_2d_avg.unsqueeze(1)
    if output.size(2) == 1:
        # A singleton layer dimension means all layers share the same mask, so drop it.
        output = output.squeeze(2)
    return output
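

if __name__ == "__main__":
    # Minimal smoke test with dummy tensors. All shapes here are assumptions
    # inferred from how mask_generation indexes its inputs, not values from
    # the original project.
    B, heads, tokens = 2, 8, 77
    h, w = 16, 16  # a small grid keeps the (h*w) x (h*w) self-attention map cheap
    crossmaps = [torch.rand(B, heads, tokens, 32, 32) for _ in range(2)]
    selfmaps = [torch.rand(B, 1, 32 * 32, 32 * 32)]
    target_token = torch.zeros(B, tokens)
    target_token[:, 1:5] = 1.0  # mark a few prompt tokens as the mask target
    mask = mask_generation(
        crossmaps, selfmaps,
        target_token=target_token, mask_scope=0.3,
        mask_target_h=h, mask_target_w=w,
        mask_mode=["min_max_norm", "threshold", "binary"],
    )
    print(mask.shape)  # expected: torch.Size([2, 1, 16, 16])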