#Taken from: https://github.com/dbolya/tomesd

import torch
from typing import Tuple, Callable
import math


def do_nothing(x: torch.Tensor, mode: str = None):
    return x
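# Workaround for torch.gather on Apple's MPS backend, which has been known to
# return incorrect results when the input's last dimension has size 1. The
# rationale here is inferred from the function's name and shape check, not
# stated in the source: gather over an unsqueezed view, then squeeze back.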
def mps_gather_workaround(input, dim, index):
    if input.shape[-1] == 1:
        return torch.gather(
            input.unsqueeze(-1),
            dim - 1 if dim < 0 else dim,
            index.unsqueeze(-1)
        ).squeeze(-1)
    else:
        return torch.gather(input, dim, index)
def bipartite_soft_matching_random2d(metric: torch.Tensor,
                                     w: int, h: int, sx: int, sy: int, r: int,
                                     no_rand: bool = False) -> Tuple[Callable, Callable]:
    """
    Partitions the tokens into src and dst and merges r tokens from src to dst.
    Dst tokens are partitioned by choosing one randomly in each (sx, sy) region.

    Args:
     - metric [B, N, C]: metric to use for similarity
     - w: image width in tokens
     - h: image height in tokens
     - sx: stride in the x dimension for dst; need not divide w
     - sy: stride in the y dimension for dst; need not divide h
     - r: number of tokens to remove (by merging)
     - no_rand: if true, disable randomness (use top left corner only)
    """
    B, N, _ = metric.shape

    if r <= 0 or w == 1 or h == 1:
        return do_nothing, do_nothing

    gather = mps_gather_workaround if metric.device.type == "mps" else torch.gather

    with torch.no_grad():
        hsy, wsx = h // sy, w // sx

        # For each sy by sx kernel, randomly assign one token to be dst and the rest src
        if no_rand:
            rand_idx = torch.zeros(hsy, wsx, 1, device=metric.device, dtype=torch.int64)
        else:
            rand_idx = torch.randint(sy*sx, size=(hsy, wsx, 1), device=metric.device)

        # sx and sy might not divide the image evenly, so work on a view of the
        # top left of the idx buffer instead
        idx_buffer_view = torch.zeros(hsy, wsx, sy*sx, device=metric.device, dtype=torch.int64)
        idx_buffer_view.scatter_(dim=2, index=rand_idx, src=-torch.ones_like(rand_idx, dtype=rand_idx.dtype))
        idx_buffer_view = idx_buffer_view.view(hsy, wsx, sy, sx).transpose(1, 2).reshape(hsy * sy, wsx * sx)

        # Image is not divisible by sx or sy, so copy the view into a full-size buffer
        if (hsy * sy) < h or (wsx * sx) < w:
            idx_buffer = torch.zeros(h, w, device=metric.device, dtype=torch.int64)
            idx_buffer[:(hsy * sy), :(wsx * sx)] = idx_buffer_view
        else:
            idx_buffer = idx_buffer_view

        # We set dst tokens to be -1 and src to be 0, so an argsort gives us dst|src indices
        rand_idx = idx_buffer.reshape(1, -1, 1).argsort(dim=1)

        # We're finished with these
        del idx_buffer, idx_buffer_view

        # rand_idx is currently dst|src, so split them
        num_dst = hsy * wsx
        a_idx = rand_idx[:, num_dst:, :]  # src
        b_idx = rand_idx[:, :num_dst, :]  # dst
        def split(x):
            C = x.shape[-1]
            src = gather(x, dim=1, index=a_idx.expand(B, N - num_dst, C))
            dst = gather(x, dim=1, index=b_idx.expand(B, num_dst, C))
            return src, dst

        # Cosine similarity between A and B
        metric = metric / metric.norm(dim=-1, keepdim=True)
        a, b = split(metric)
        scores = a @ b.transpose(-1, -2)

        # Can't reduce more than the # tokens in src
        r = min(a.shape[1], r)

        # Find the most similar greedily
        node_max, node_idx = scores.max(dim=-1)
        edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]

        unm_idx = edge_idx[..., r:, :]  # Unmerged Tokens
        src_idx = edge_idx[..., :r, :]  # Merged Tokens
        dst_idx = gather(node_idx[..., None], dim=-2, index=src_idx)

    def merge(x: torch.Tensor, mode="mean") -> torch.Tensor:
        src, dst = split(x)
        n, t1, c = src.shape
        unm = gather(src, dim=-2, index=unm_idx.expand(n, t1 - r, c))
        src = gather(src, dim=-2, index=src_idx.expand(n, r, c))
        dst = dst.scatter_reduce(-2, dst_idx.expand(n, r, c), src, reduce=mode)
        return torch.cat([unm, dst], dim=1)

    def unmerge(x: torch.Tensor) -> torch.Tensor:
        unm_len = unm_idx.shape[1]
        unm, dst = x[..., :unm_len, :], x[..., unm_len:, :]
        _, _, c = unm.shape

        src = gather(dst, dim=-2, index=dst_idx.expand(B, r, c))

        # Combine back to the original shape
        out = torch.zeros(B, N, c, device=x.device, dtype=x.dtype)
        out.scatter_(dim=-2, index=b_idx.expand(B, num_dst, c), src=dst)
        out.scatter_(dim=-2, index=gather(a_idx.expand(B, a_idx.shape[1], 1), dim=1, index=unm_idx).expand(B, unm_len, c), src=unm)
        out.scatter_(dim=-2, index=gather(a_idx.expand(B, a_idx.shape[1], 1), dim=1, index=src_idx).expand(B, r, c), src=src)

        return out

    return merge, unmerge
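# A minimal usage sketch (illustrative only; the tensor `tokens` and its
# dimensions are assumptions, not part of this file): for a 64x64 token grid,
# merge half of the tokens, then expand back to the original length.
#
#   tokens = torch.randn(1, 64 * 64, 320)
#   merge, unmerge = bipartite_soft_matching_random2d(
#       tokens, w=64, h=64, sx=2, sy=2, r=tokens.shape[1] // 2)
#   merged = merge(tokens)      # [1, 64*64 - 2048, 320] == [1, 2048, 320]
#   restored = unmerge(merged)  # [1, 4096, 320]; each merged src token takes
#                               # the value of the dst token it merged into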
def get_functions(x, ratio, original_shape):
    # Build merge/unmerge functions for the token tensor x based on the
    # latent's original spatial shape; merging is only applied at full
    # resolution (downsample == 1).
    b, c, original_h, original_w = original_shape
    original_tokens = original_h * original_w
    downsample = int(math.ceil(math.sqrt(original_tokens // x.shape[1])))
    stride_x = 2
    stride_y = 2
    max_downsample = 1

    if downsample <= max_downsample:
        w = int(math.ceil(original_w / downsample))
        h = int(math.ceil(original_h / downsample))
        r = int(x.shape[1] * ratio)
        no_rand = False
        m, u = bipartite_soft_matching_random2d(x, w, h, stride_x, stride_y, r, no_rand)
        return m, u

    nothing = lambda y: y
    return nothing, nothing
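# Worked example of the gating above (all numbers assumed for illustration):
# for a 512x512 image the latent is 64x64, so original_tokens == 4096. In the
# highest-resolution attn1 blocks x.shape[1] == 4096, giving
# downsample == ceil(sqrt(4096 // 4096)) == 1 <= max_downsample, so merging is
# applied. In a downsampled block with x.shape[1] == 1024,
# downsample == ceil(sqrt(4)) == 2, so the identity functions are returned.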
class TomePatchModel:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"model": ("MODEL",),
                             "ratio": ("FLOAT", {"default": 0.3, "min": 0.0, "max": 1.0, "step": 0.01}),
                             }}

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"

    CATEGORY = "_for_testing"

    def patch(self, model, ratio):
        self.u = None

        def tomesd_m(q, k, v, extra_options):
            # NOTE: In the reference code, get_functions takes x (the input of
            # the transformer block) as the argument instead of q; however,
            # from my basic testing, using q seems to give better results.
            m, self.u = get_functions(q, ratio, extra_options["original_shape"])
            return m(q), k, v

        def tomesd_u(n, extra_options):
            return self.u(n)

        m = model.clone()
        m.set_model_attn1_patch(tomesd_m)
        m.set_model_attn1_output_patch(tomesd_u)
        return (m, )
NODE_CLASS_MAPPINGS = {
    "TomePatchModel": TomePatchModel,
}
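# A minimal usage sketch (assumption: `model` is a loaded ComfyUI MODEL, e.g.
# the output of a checkpoint loader node; this mirrors how ComfyUI would
# invoke the node but is not part of this file):
#
#   node = TomePatchModel()
#   (patched,) = node.patch(model, ratio=0.3)
#
# `patched` is a clone of the input model with the ToMe attn1 patches applied;
# it can be passed to a sampler in place of the original, which is unchanged.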