# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------

import math
from typing import Callable, Tuple

import torch


def do_nothing(x, mode=None):
    return x


def bipartite_soft_matching(
    metric: torch.Tensor,
    r: int,
    class_token: bool = False,
    distill_token: bool = False,
) -> Tuple[Callable, Callable]:
    """
    Applies ToMe with a balanced matching set (50%, 50%).

    Input size is [batch, tokens, channels].
    r indicates the number of tokens to remove (max 50% of tokens).

    Extra args:
     - class_token: Whether or not there's a class token.
     - distill_token: Whether or not there's also a distillation token.

    When enabled, the class token and distillation tokens won't get merged.
    """
    protected = 0
    if class_token:
        protected += 1
    if distill_token:
        protected += 1

    # We can only reduce by a maximum of 50% tokens
    t = metric.shape[1]
    r = min(r, (t - protected) // 2)

    if r <= 0:
        return do_nothing, do_nothing

    with torch.no_grad():
        metric = metric / metric.norm(dim=-1, keepdim=True)
        a, b = metric[..., ::2, :], metric[..., 1::2, :]
        scores = a @ b.transpose(-1, -2)

        if class_token:
            scores[..., 0, :] = -math.inf
        if distill_token:
            scores[..., :, 0] = -math.inf

        node_max, node_idx = scores.max(dim=-1)
        edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]

        unm_idx = edge_idx[..., r:, :]  # Unmerged Tokens
        src_idx = edge_idx[..., :r, :]  # Merged Tokens
        dst_idx = node_idx[..., None].gather(dim=-2, index=src_idx)

        if class_token:
            # Sort to ensure the class token is at the start
            unm_idx = unm_idx.sort(dim=1)[0]

    def merge(x: torch.Tensor, mode="mean") -> torch.Tensor:
        src, dst = x[..., ::2, :], x[..., 1::2, :]
        n, t1, c = src.shape
        unm = src.gather(dim=-2, index=unm_idx.expand(n, t1 - r, c))
        src = src.gather(dim=-2, index=src_idx.expand(n, r, c))
        dst = dst.scatter_reduce(-2, dst_idx.expand(n, r, c), src, reduce=mode)

        if distill_token:
            return torch.cat([unm[:, :1], dst[:, :1], unm[:, 1:], dst[:, 1:]], dim=1)
        else:
            return torch.cat([unm, dst], dim=1)

    def unmerge(x: torch.Tensor) -> torch.Tensor:
        unm_len = unm_idx.shape[1]
        unm, dst = x[..., :unm_len, :], x[..., unm_len:, :]
        n, _, c = unm.shape

        src = dst.gather(dim=-2, index=dst_idx.expand(n, r, c))

        out = torch.zeros(n, metric.shape[1], c, device=x.device, dtype=x.dtype)

        out[..., 1::2, :] = dst
        out.scatter_(dim=-2, index=(2 * unm_idx).expand(n, unm_len, c), src=unm)
        out.scatter_(dim=-2, index=(2 * src_idx).expand(n, r, c), src=src)

        return out

    return merge, unmerge
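

# Illustrative usage sketch (not part of the original module). The batch, token,
# and channel sizes below are arbitrary, and the token features are used directly
# as the matching metric; in ToMe the metric would typically be the attention keys.
def _example_bipartite_soft_matching():
    x = torch.randn(2, 16, 64)                  # [batch, tokens, channels]
    merge, unmerge = bipartite_soft_matching(x, r=4)
    merged = merge(x)                           # [2, 12, 64]: 4 tokens merged away
    restored = unmerge(merged)                  # [2, 16, 64]: merged values copied back
    return merged.shape, restored.shape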


def kth_bipartite_soft_matching(
    metric: torch.Tensor, k: int
) -> Tuple[Callable, Callable]:
    """
    Applies ToMe with the two sets as (every kth element, the rest).
    If n is the number of tokens, the resulting number of tokens will be n // k.

    Input size is [batch, tokens, channels].
    k indicates the stride for the first set.
    k = 2 is equivalent to regular bipartite_soft_matching with r = 0.5 * N.
    """
    if k <= 1:
        return do_nothing, do_nothing

    def split(x):
        # Group tokens in runs of k: the first (k - 1) of each run form the
        # source set (a); every kth token is a merge destination (b).
        t_rnd = (x.shape[1] // k) * k
        x = x[:, :t_rnd, :].view(x.shape[0], -1, k, x.shape[2])
        a, b = (
            x[:, :, : (k - 1), :].contiguous().view(x.shape[0], -1, x.shape[-1]),
            x[:, :, (k - 1), :],
        )
        return a, b

    with torch.no_grad():
        metric = metric / metric.norm(dim=-1, keepdim=True)
        a, b = split(metric)
        r = a.shape[1]
        scores = a @ b.transpose(-1, -2)

        _, dst_idx = scores.max(dim=-1)
        dst_idx = dst_idx[..., None]

    def merge(x: torch.Tensor, mode="mean") -> torch.Tensor:
        src, dst = split(x)
        n, _, c = src.shape
        dst = dst.scatter_reduce(-2, dst_idx.expand(n, r, c), src, reduce=mode)

        return dst

    def unmerge(x: torch.Tensor) -> torch.Tensor:
        n, _, c = x.shape
        dst = x

        src = dst.gather(dim=-2, index=dst_idx.expand(n, r, c)).to(x.dtype)

        src = src.view(n, -1, (k - 1), c)
        dst = dst.view(n, -1, 1, c)

        out = torch.cat([src, dst], dim=-2)
        out = out.contiguous().view(n, -1, c)
        return out

    return merge, unmerge
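

# Illustrative usage sketch (not part of the original module). With k = 2 the
# token count is halved, matching the docstring's note about
# bipartite_soft_matching with r = 0.5 * N; all sizes below are arbitrary.
def _example_kth_bipartite_soft_matching():
    x = torch.randn(2, 16, 64)                  # [batch, tokens, channels]
    merge, unmerge = kth_bipartite_soft_matching(x, k=2)
    merged = merge(x)                           # [2, 8, 64]: n // k tokens remain
    restored = unmerge(merged)                  # [2, 16, 64]
    return merged.shape, restored.shape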


def random_bipartite_soft_matching(
    metric: torch.Tensor, r: int
) -> Tuple[Callable, Callable]:
    """
    Applies ToMe with the two sets as (r chosen randomly, the rest).
    Input size is [batch, tokens, channels].

    This will reduce the number of tokens by r.
    """
    if r <= 0:
        return do_nothing, do_nothing

    with torch.no_grad():
        B, N, _ = metric.shape
        # Randomly assign r tokens to the source set (a); the rest are destinations (b).
        rand_idx = torch.rand(B, N, 1, device=metric.device).argsort(dim=1)

        a_idx = rand_idx[:, :r, :]
        b_idx = rand_idx[:, r:, :]

        def split(x):
            C = x.shape[-1]
            a = x.gather(dim=1, index=a_idx.expand(B, r, C))
            b = x.gather(dim=1, index=b_idx.expand(B, N - r, C))
            return a, b

        metric = metric / metric.norm(dim=-1, keepdim=True)
        a, b = split(metric)
        scores = a @ b.transpose(-1, -2)

        _, dst_idx = scores.max(dim=-1)
        dst_idx = dst_idx[..., None]

    def merge(x: torch.Tensor, mode="mean") -> torch.Tensor:
        src, dst = split(x)
        C = src.shape[-1]
        dst = dst.scatter_reduce(-2, dst_idx.expand(B, r, C), src, reduce=mode)

        return dst

    def unmerge(x: torch.Tensor) -> torch.Tensor:
        C = x.shape[-1]
        dst = x
        src = dst.gather(dim=-2, index=dst_idx.expand(B, r, C))

        out = torch.zeros(B, N, C, device=x.device, dtype=x.dtype)

        out.scatter_(dim=-2, index=a_idx.expand(B, r, C), src=src)
        out.scatter_(dim=-2, index=b_idx.expand(B, N - r, C), src=dst)

        return out

    return merge, unmerge
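

# Illustrative usage sketch (not part of the original module). A random set of
# r = 4 tokens is merged into the remaining 12; all sizes below are arbitrary.
def _example_random_bipartite_soft_matching():
    x = torch.randn(2, 16, 64)                  # [batch, tokens, channels]
    merge, unmerge = random_bipartite_soft_matching(x, r=4)
    merged = merge(x)                           # [2, 12, 64]
    restored = unmerge(merged)                  # [2, 16, 64]
    return merged.shape, restored.shape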


def merge_wavg(
    merge: Callable, x: torch.Tensor, size: torch.Tensor = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Applies the merge function by taking a weighted average based on token size.
    Returns the merged tensor and the new token sizes.
    """
    if size is None:
        size = torch.ones_like(x[..., 0, None])

    x = merge(x * size, mode="sum")
    size = merge(size, mode="sum")

    x = x / size
    return x, size
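

# Illustrative usage sketch (not part of the original module). Two successive
# merges are applied as size-weighted averages, so a token that already stands
# for several originals is not diluted by later merges; the sizes and r values
# are arbitrary, and the features double as the matching metric here.
def _example_merge_wavg():
    x = torch.randn(2, 16, 64)                  # [batch, tokens, channels]
    size = None
    for r in (4, 4):
        merge, _ = bipartite_soft_matching(x, r=r)
        x, size = merge_wavg(merge, x, size)
    return x.shape, size.shape                  # [2, 8, 64], [2, 8, 1]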


def merge_source(
    merge: Callable, x: torch.Tensor, source: torch.Tensor = None
) -> torch.Tensor:
    """
    For source tracking. Source is an adjacency matrix between the initial tokens
    and the final merged groups.
    x is used to find out how many tokens there are in case the source is None.
    """
    if source is None:
        n, t, _ = x.shape
        source = torch.eye(t, device=x.device)[None, ...].expand(n, t, t)

    source = merge(source, mode="amax")
    return source
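

# Illustrative usage sketch (not part of the original module). The source matrix
# records, for each merged token, which original tokens it contains; sizes and r
# are arbitrary, and the features double as the matching metric here.
def _example_merge_source():
    x = torch.randn(2, 16, 64)                  # [batch, tokens, channels]
    merge, _ = bipartite_soft_matching(x, r=4)
    source = merge_source(merge, x)             # [2, 12, 16]: 0/1 membership rows
    x, size = merge_wavg(merge, x)              # [2, 12, 64], [2, 12, 1]
    return source.shape, x.shape, size.shape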