Spaces:
Runtime error
Runtime error
| # Modified from Flux | |
| # | |
| # Copyright 2024 Black Forest Labs | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| # | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import torch | |
| from einops import rearrange | |
| from torch import Tensor, nn | |
| def swish(x: Tensor) -> Tensor: | |
| return x * torch.sigmoid(x) | |
| class AttnBlock(nn.Module): | |
| def __init__(self, in_channels: int): | |
| super().__init__() | |
| self.in_channels = in_channels | |
| self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) | |
| self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1) | |
| self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1) | |
| self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1) | |
| self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1) | |
| def attention(self, h_: Tensor) -> Tensor: | |
| h_ = self.norm(h_) | |
| q = self.q(h_) | |
| k = self.k(h_) | |
| v = self.v(h_) | |
| b, c, h, w = q.shape | |
| q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous() | |
| k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous() | |
| v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous() | |
| h_ = nn.functional.scaled_dot_product_attention(q, k, v) | |
| return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b) | |
| def forward(self, x: Tensor) -> Tensor: | |
| return x + self.proj_out(self.attention(x)) | |
| class ResnetBlock(nn.Module): | |
| def __init__(self, in_channels: int, out_channels: int): | |
| super().__init__() | |
| self.in_channels = in_channels | |
| out_channels = in_channels if out_channels is None else out_channels | |
| self.out_channels = out_channels | |
| self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) | |
| self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) | |
| self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True) | |
| self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) | |
| if self.in_channels != self.out_channels: | |
| self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) | |
| def forward(self, x): | |
| h = x | |
| h = self.norm1(h) | |
| h = swish(h) | |
| h = self.conv1(h) | |
| h = self.norm2(h) | |
| h = swish(h) | |
| h = self.conv2(h) | |
| if self.in_channels != self.out_channels: | |
| x = self.nin_shortcut(x) | |
| return x + h | |
| class Downsample(nn.Module): | |
| def __init__(self, in_channels: int): | |
| super().__init__() | |
| # no asymmetric padding in torch conv, must do it ourselves | |
| self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) | |
| def forward(self, x: Tensor): | |
| pad = (0, 1, 0, 1) | |
| x = nn.functional.pad(x, pad, mode="constant", value=0) | |
| x = self.conv(x) | |
| return x | |
| class Upsample(nn.Module): | |
| def __init__(self, in_channels: int): | |
| super().__init__() | |
| self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) | |
| def forward(self, x: Tensor): | |
| x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") | |
| x = self.conv(x) | |
| return x | |
| class Encoder(nn.Module): | |
| def __init__( | |
| self, | |
| resolution: int, | |
| in_channels: int, | |
| ch: int, | |
| ch_mult: list[int], | |
| num_res_blocks: int, | |
| z_channels: int, | |
| ): | |
| super().__init__() | |
| self.ch = ch | |
| self.num_resolutions = len(ch_mult) | |
| self.num_res_blocks = num_res_blocks | |
| self.resolution = resolution | |
| self.in_channels = in_channels | |
| # downsampling | |
| self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1) | |
| curr_res = resolution | |
| in_ch_mult = (1, *tuple(ch_mult)) | |
| self.in_ch_mult = in_ch_mult | |
| self.down = nn.ModuleList() | |
| block_in = self.ch | |
| for i_level in range(self.num_resolutions): | |
| block = nn.ModuleList() | |
| attn = nn.ModuleList() | |
| block_in = ch * in_ch_mult[i_level] | |
| block_out = ch * ch_mult[i_level] | |
| for _ in range(self.num_res_blocks): | |
| block.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) | |
| block_in = block_out | |
| down = nn.Module() | |
| down.block = block | |
| down.attn = attn | |
| if i_level != self.num_resolutions - 1: | |
| down.downsample = Downsample(block_in) | |
| curr_res = curr_res // 2 | |
| self.down.append(down) | |
| # middle | |
| self.mid = nn.Module() | |
| self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in) | |
| self.mid.attn_1 = AttnBlock(block_in) | |
| self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in) | |
| # end | |
| self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True) | |
| self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1) | |
| def forward(self, x: Tensor) -> Tensor: | |
| # downsampling | |
| hs = [self.conv_in(x)] | |
| for i_level in range(self.num_resolutions): | |
| for i_block in range(self.num_res_blocks): | |
| h = self.down[i_level].block[i_block](hs[-1]) | |
| if len(self.down[i_level].attn) > 0: | |
| h = self.down[i_level].attn[i_block](h) | |
| hs.append(h) | |
| if i_level != self.num_resolutions - 1: | |
| hs.append(self.down[i_level].downsample(hs[-1])) | |
| # middle | |
| h = hs[-1] | |
| h = self.mid.block_1(h) | |
| h = self.mid.attn_1(h) | |
| h = self.mid.block_2(h) | |
| # end | |
| h = self.norm_out(h) | |
| h = swish(h) | |
| h = self.conv_out(h) | |
| return h | |
| class Decoder(nn.Module): | |
| def __init__( | |
| self, | |
| ch: int, | |
| out_ch: int, | |
| ch_mult: list[int], | |
| num_res_blocks: int, | |
| in_channels: int, | |
| resolution: int, | |
| z_channels: int, | |
| ): | |
| super().__init__() | |
| self.ch = ch | |
| self.num_resolutions = len(ch_mult) | |
| self.num_res_blocks = num_res_blocks | |
| self.resolution = resolution | |
| self.in_channels = in_channels | |
| self.ffactor = 2 ** (self.num_resolutions - 1) | |
| # compute in_ch_mult, block_in and curr_res at lowest res | |
| block_in = ch * ch_mult[self.num_resolutions - 1] | |
| curr_res = resolution // 2 ** (self.num_resolutions - 1) | |
| self.z_shape = (1, z_channels, curr_res, curr_res) | |
| # z to block_in | |
| self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1) | |
| # middle | |
| self.mid = nn.Module() | |
| self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in) | |
| self.mid.attn_1 = AttnBlock(block_in) | |
| self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in) | |
| # upsampling | |
| self.up = nn.ModuleList() | |
| for i_level in reversed(range(self.num_resolutions)): | |
| block = nn.ModuleList() | |
| attn = nn.ModuleList() | |
| block_out = ch * ch_mult[i_level] | |
| for _ in range(self.num_res_blocks + 1): | |
| block.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) | |
| block_in = block_out | |
| up = nn.Module() | |
| up.block = block | |
| up.attn = attn | |
| if i_level != 0: | |
| up.upsample = Upsample(block_in) | |
| curr_res = curr_res * 2 | |
| self.up.insert(0, up) # prepend to get consistent order | |
| # end | |
| self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True) | |
| self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) | |
| def forward(self, z: Tensor) -> Tensor: | |
| # z to block_in | |
| h = self.conv_in(z) | |
| # middle | |
| h = self.mid.block_1(h) | |
| h = self.mid.attn_1(h) | |
| h = self.mid.block_2(h) | |
| # upsampling | |
| for i_level in reversed(range(self.num_resolutions)): | |
| for i_block in range(self.num_res_blocks + 1): | |
| h = self.up[i_level].block[i_block](h) | |
| if len(self.up[i_level].attn) > 0: | |
| h = self.up[i_level].attn[i_block](h) | |
| if i_level != 0: | |
| h = self.up[i_level].upsample(h) | |
| # end | |
| h = self.norm_out(h) | |
| h = swish(h) | |
| h = self.conv_out(h) | |
| return h | |
| class DiagonalGaussian(nn.Module): | |
| def __init__(self, sample: bool = True, chunk_dim: int = 1): | |
| super().__init__() | |
| self.sample = sample | |
| self.chunk_dim = chunk_dim | |
| def forward(self, z: Tensor) -> Tensor: | |
| mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim) | |
| if self.sample: | |
| std = torch.exp(0.5 * logvar) | |
| return mean + std * torch.randn_like(mean) | |
| else: | |
| return mean | |
| class AutoEncoder(nn.Module): | |
| def __init__( | |
| self, | |
| resolution: int, | |
| in_channels: int, | |
| ch: int, | |
| out_ch: int, | |
| ch_mult: list[int], | |
| num_res_blocks: int, | |
| z_channels: int, | |
| scale_factor: float, | |
| shift_factor: float, | |
| ): | |
| super().__init__() | |
| self.encoder = Encoder( | |
| resolution=resolution, | |
| in_channels=in_channels, | |
| ch=ch, | |
| ch_mult=ch_mult, | |
| num_res_blocks=num_res_blocks, | |
| z_channels=z_channels, | |
| ) | |
| self.decoder = Decoder( | |
| resolution=resolution, | |
| in_channels=in_channels, | |
| ch=ch, | |
| out_ch=out_ch, | |
| ch_mult=ch_mult, | |
| num_res_blocks=num_res_blocks, | |
| z_channels=z_channels, | |
| ) | |
| self.reg = DiagonalGaussian() | |
| self.scale_factor = scale_factor | |
| self.shift_factor = shift_factor | |
| def encode(self, x: Tensor) -> Tensor: | |
| z = self.reg(self.encoder(x)) | |
| z = self.scale_factor * (z - self.shift_factor) | |
| return z | |
| def decode(self, z: Tensor) -> Tensor: | |
| z = z / self.scale_factor + self.shift_factor | |
| return self.decoder(z) | |
| def forward(self, x: Tensor) -> Tensor: | |
| return self.decode(self.encode(x)) | |