import math
from functools import partial

import torch
from torch import nn
class AuraFlowPatchEmbed(nn.Module):
    def __init__(
        self,
        height=224,
        width=224,
        patch_size=16,
        in_channels=3,
        embed_dim=768,
        pos_embed_max_size=None,
    ):
        super().__init__()

        self.num_patches = (height // patch_size) * (width // patch_size)
        self.pos_embed_max_size = pos_embed_max_size

        self.proj = nn.Linear(patch_size * patch_size * in_channels, embed_dim)
        self.pos_embed = nn.Parameter(torch.randn(1, pos_embed_max_size, embed_dim) * 0.1)

        self.patch_size = patch_size
        self.height, self.width = height // patch_size, width // patch_size
        self.base_size = height // patch_size
    def forward(self, latent):
        batch_size, num_channels, height, width = latent.size()

        # (B, C, H, W) -> (B, C, H/p, p, W/p, p) -> (B, num_patches, p * p * C)
        latent = latent.view(
            batch_size,
            num_channels,
            height // self.patch_size,
            self.patch_size,
            width // self.patch_size,
            self.patch_size,
        )
        latent = latent.permute(0, 2, 4, 1, 3, 5).flatten(-3).flatten(1, 2)
        latent = self.proj(latent)
        try:
            return latent + self.pos_embed
        except RuntimeError as e:
            raise RuntimeError(
                f"Positional embeddings are too small for the number of patches. "
                f"Please increase `pos_embed_max_size` to at least {latent.shape[1]}."
            ) from e
    # comfy
    # def apply_pos_embeds(self, x, h, w):
    #     h = (h + 1) // self.patch_size
    #     w = (w + 1) // self.patch_size
    #     max_dim = max(h, w)
    #
    #     cur_dim = self.h_max
    #     pos_encoding = self.positional_encoding.reshape(1, cur_dim, cur_dim, -1).to(device=x.device, dtype=x.dtype)
    #
    #     if max_dim > cur_dim:
    #         pos_encoding = F.interpolate(pos_encoding.movedim(-1, 1), (max_dim, max_dim), mode="bilinear").movedim(1, -1)
    #         cur_dim = max_dim
    #
    #     from_h = (cur_dim - h) // 2
    #     from_w = (cur_dim - w) // 2
    #     pos_encoding = pos_encoding[:, from_h:from_h + h, from_w:from_w + w]
    #     return x + pos_encoding.reshape(1, -1, self.positional_encoding.shape[-1])

    # def patchify(self, x):
    #     B, C, H, W = x.size()
    #     pad_h = (self.patch_size - H % self.patch_size) % self.patch_size
    #     pad_w = (self.patch_size - W % self.patch_size) % self.patch_size
    #
    #     x = torch.nn.functional.pad(x, (0, pad_w, 0, pad_h), mode='reflect')
    #     x = x.view(
    #         B,
    #         C,
    #         (H + 1) // self.patch_size,
    #         self.patch_size,
    #         (W + 1) // self.patch_size,
    #         self.patch_size,
    #     )
    #     x = x.permute(0, 2, 4, 1, 3, 5).flatten(-3).flatten(1, 2)
    #     return x
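
# Shape sanity check (a sketch, not part of the original file): the view/permute/flatten
# pipeline used in `forward` above is an im2col-style extraction of non-overlapping
# patches. The values of B, C, H, W and p below are illustrative only.
#
#   B, C, H, W, p = 2, 4, 8, 8, 2
#   x = torch.randn(B, C, H, W)
#   a = x.view(B, C, H // p, p, W // p, p).permute(0, 2, 4, 1, 3, 5).flatten(-3).flatten(1, 2)
#   b = torch.nn.functional.unfold(x, kernel_size=p, stride=p).transpose(1, 2)
#   assert torch.equal(a, b)  # both are (B, (H // p) * (W // p), p * p * C)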
def patch_auraflow_pos_embed(pos_embed):
    # we need to hijack the forward and replace it with a custom one.
    # `self` is the pos_embed module being patched.
    def new_forward(self, latent):
        batch_size, num_channels, height, width = latent.size()

        # add padding to the latent to make it match pos_embed
        latent_size = height * width * num_channels / 16  # todo check where 16 comes from? (likely in_channels * patch_size ** 2 = 4 * 2 ** 2 for AuraFlow latents)
        pos_embed_size = self.pos_embed.shape[1]
        if latent_size < pos_embed_size:
            total_padding = int(pos_embed_size - math.floor(latent_size))
            total_padding = total_padding // 16
            pad_height = total_padding // 2
            pad_width = total_padding - pad_height
            # mirror padding on the right and bottom edges
            padding = (0, pad_width, 0, pad_height)
            latent = torch.nn.functional.pad(latent, padding, mode='reflect')
        elif latent_size > pos_embed_size:
            # crop the excess trailing rows from the height dimension
            amount_to_remove = int(latent_size - pos_embed_size)
            latent = latent[:, :, :-amount_to_remove]

        batch_size, num_channels, height, width = latent.size()

        # (B, C, H, W) -> (B, C, H/p, p, W/p, p) -> (B, num_patches, p * p * C)
        latent = latent.view(
            batch_size,
            num_channels,
            height // self.patch_size,
            self.patch_size,
            width // self.patch_size,
            self.patch_size,
        )
        latent = latent.permute(0, 2, 4, 1, 3, 5).flatten(-3).flatten(1, 2)
        latent = self.proj(latent)
        try:
            return latent + self.pos_embed
        except RuntimeError as e:
            raise RuntimeError(
                f"Positional embeddings are too small for the number of patches. "
                f"Please increase `pos_embed_max_size` to at least {latent.shape[1]}."
            ) from e

    pos_embed.forward = partial(new_forward, pos_embed)
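
if __name__ == "__main__":
    # Minimal usage sketch (an assumption, not part of the original module): in practice
    # patch_auraflow_pos_embed would be applied to a model's existing pos_embed module;
    # here a small AuraFlowPatchEmbed is built directly just to show the call.
    embed = AuraFlowPatchEmbed(
        height=64,
        width=64,
        patch_size=2,
        in_channels=4,
        embed_dim=32,
        pos_embed_max_size=(64 // 2) * (64 // 2),
    )
    latent = torch.randn(1, 4, 64, 64)
    print(embed(latent).shape)  # torch.Size([1, 1024, 32])

    # install the padding/cropping forward; the output is unchanged when the latent
    # already matches the positional embedding size
    patch_auraflow_pos_embed(embed)
    print(embed(latent).shape)  # torch.Size([1, 1024, 32])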