from abc import abstractmethod
from collections.abc import Iterable
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
import matplotlib.pyplot as plt
import seaborn as sns
from ...modules.attention import SpatialTransformer
from ...modules.diffusionmodules.util import (
avg_pool_nd,
conv_nd,
linear,
normalization,
timestep_embedding,
zero_module,
)
from ...util import default, exists
class Timestep(nn.Module):
def __init__(self, dim):
super().__init__()
self.dim = dim
def forward(self, t):
return timestep_embedding(t, self.dim)
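

# Illustrative usage sketch (not part of the original module): `Timestep(dim)`
# wraps the sinusoidal `timestep_embedding` helper, mapping a batch of scalar
# timesteps of shape (N,) to embeddings of shape (N, dim). Shapes assume the
# standard helper from `diffusionmodules.util`.
def _example_timestep():
    t = th.tensor([0, 250, 999])   # (N,)
    emb = Timestep(320)(t)         # (N, 320)
    return emb.shape
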
class TimestepBlock(nn.Module):
"""
Any module where forward() takes timestep embeddings as a second argument.
"""
@abstractmethod
def forward(self, x, emb):
"""
Apply the module to `x` given `emb` timestep embeddings.
"""
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
"""
A sequential module that passes timestep embeddings to the children that
support it as an extra input.
"""
def forward(
self,
x,
emb,
t_context=None,
v_context=None
):
for layer in self:
if isinstance(layer, TimestepBlock):
x = layer(x, emb)
elif isinstance(layer, SpatialTransformer):
x = layer(x, t_context, v_context)
else:
x = layer(x)
return x
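

# Illustrative usage sketch (not part of the original module): shows how a
# TimestepEmbedSequential dispatches its extra inputs. Plain layers receive
# only `x`, TimestepBlock children also receive `emb`, and SpatialTransformer
# children receive the two context tensors instead. Shapes are assumptions;
# `ResBlock` is defined further below in this file.
def _example_timestep_embed_sequential():
    block = TimestepEmbedSequential(
        conv_nd(2, 32, 64, 3, padding=1),  # plain layer: sees x only
        ResBlock(64, 128, dropout=0.0),    # TimestepBlock: sees (x, emb)
    )
    x = th.randn(2, 32, 16, 16)
    emb = th.randn(2, 128)
    return block(x, emb)  # t_context / v_context default to None
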
class Upsample(nn.Module):
"""
An upsampling layer with an optional convolution.
:param channels: channels in the inputs and outputs.
:param use_conv: a bool determining if a convolution is applied.
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
upsampling occurs in the inner-two dimensions.
"""
def __init__(
self, channels, use_conv, dims=2, out_channels=None, padding=1, third_up=False
):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.dims = dims
self.third_up = third_up
if use_conv:
self.conv = conv_nd(
dims, self.channels, self.out_channels, 3, padding=padding
)
def forward(self, x):
assert x.shape[1] == self.channels
if self.dims == 3:
t_factor = 1 if not self.third_up else 2
x = F.interpolate(
x,
(t_factor * x.shape[2], x.shape[3] * 2, x.shape[4] * 2),
mode="nearest",
)
else:
x = F.interpolate(x, scale_factor=2, mode="nearest")
if self.use_conv:
x = self.conv(x)
return x
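

# Shape sketch (illustrative assumptions): nearest-neighbour interpolation
# doubles the two spatial dimensions; with `use_conv=True` a 3x3 convolution
# is applied afterwards.
def _example_upsample():
    up = Upsample(channels=64, use_conv=True, dims=2)
    y = up(th.randn(1, 64, 16, 16))  # -> (1, 64, 32, 32)
    return y.shape
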
class Downsample(nn.Module):
"""
A downsampling layer with an optional convolution.
:param channels: channels in the inputs and outputs.
:param use_conv: a bool determining if a convolution is applied.
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
downsampling occurs in the inner-two dimensions.
"""
def __init__(
self, channels, use_conv, dims=2, out_channels=None, padding=1, third_down=False
):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.dims = dims
stride = 2 if dims != 3 else ((1, 2, 2) if not third_down else (2, 2, 2))
        if use_conv:
            self.op = conv_nd(
                dims,
                self.channels,
                self.out_channels,
                3,
                stride=stride,
                padding=padding,
            )
else:
assert self.channels == self.out_channels
self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
def forward(self, x):
assert x.shape[1] == self.channels
return self.op(x)
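

# Shape sketch (illustrative assumptions): a strided 3x3 convolution (or
# average pooling when `use_conv=False`) halves the two spatial dimensions.
def _example_downsample():
    down = Downsample(channels=64, use_conv=True, dims=2)
    y = down(th.randn(1, 64, 32, 32))  # -> (1, 64, 16, 16)
    return y.shape
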
class ResBlock(TimestepBlock):
"""
A residual block that can optionally change the number of channels.
"""
def __init__(
self,
channels,
emb_channels,
dropout,
out_channels=None,
use_conv=False,
use_scale_shift_norm=False,
dims=2,
up=False,
down=False,
kernel_size=3,
exchange_temb_dims=False,
skip_t_emb=False
):
super().__init__()
self.channels = channels
self.emb_channels = emb_channels
self.dropout = dropout
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.use_scale_shift_norm = use_scale_shift_norm
self.exchange_temb_dims = exchange_temb_dims
if isinstance(kernel_size, Iterable):
padding = [k // 2 for k in kernel_size]
else:
padding = kernel_size // 2
self.in_layers = nn.Sequential(
normalization(channels),
nn.SiLU(),
conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding),
)
self.updown = up or down
if up:
self.h_upd = Upsample(channels, False, dims)
self.x_upd = Upsample(channels, False, dims)
elif down:
self.h_upd = Downsample(channels, False, dims)
self.x_upd = Downsample(channels, False, dims)
else:
self.h_upd = self.x_upd = nn.Identity()
self.skip_t_emb = skip_t_emb
self.emb_out_channels = (
2 * self.out_channels if use_scale_shift_norm else self.out_channels
)
if self.skip_t_emb:
print(f"Skipping timestep embedding in {self.__class__.__name__}")
assert not self.use_scale_shift_norm
self.emb_layers = None
self.exchange_temb_dims = False
else:
self.emb_layers = nn.Sequential(
nn.SiLU(),
linear(
emb_channels,
self.emb_out_channels,
),
)
self.out_layers = nn.Sequential(
normalization(self.out_channels),
nn.SiLU(),
nn.Dropout(p=dropout),
zero_module(
conv_nd(
dims,
self.out_channels,
self.out_channels,
kernel_size,
padding=padding,
)
),
)
if self.out_channels == channels:
self.skip_connection = nn.Identity()
elif use_conv:
self.skip_connection = conv_nd(
dims, channels, self.out_channels, kernel_size, padding=padding
)
else:
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
def forward(self, x, emb):
if self.updown:
in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
h = in_rest(x)
h = self.h_upd(h)
x = self.x_upd(x)
h = in_conv(h)
else:
h = self.in_layers(x)
if self.skip_t_emb:
emb_out = th.zeros_like(h)
else:
emb_out = self.emb_layers(emb).type(h.dtype)
while len(emb_out.shape) < len(h.shape):
emb_out = emb_out[..., None]
if self.use_scale_shift_norm:
out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
scale, shift = th.chunk(emb_out, 2, dim=1)
h = out_norm(h) * (1 + scale) + shift
h = out_rest(h)
else:
if self.exchange_temb_dims:
emb_out = rearrange(emb_out, "b t c ... -> b c t ...")
h = h + emb_out
h = self.out_layers(h)
return self.skip_connection(x) + h
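

# Conditioning sketch (illustrative assumptions): with `use_scale_shift_norm=True`
# the timestep embedding is projected to (scale, shift) and applied as
# h = norm(h) * (1 + scale) + shift (FiLM-style); otherwise the projected
# embedding is simply added to the hidden features.
def _example_resblock():
    block = ResBlock(64, 256, dropout=0.0, out_channels=128,
                     use_scale_shift_norm=True)
    x = th.randn(2, 64, 16, 16)
    emb = th.randn(2, 256)
    return block(x, emb).shape  # -> (2, 128, 16, 16)
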
class UnifiedUNetModel(nn.Module):
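    """
    A UNet with sinusoidal timestep embedding, optional label/ADM conditioning
    (`use_label` / `adm_in_channels`), an optional zero-initialised control
    branch (`ctrl_channels`) whose features are added to the stem output, and
    SpatialTransformer blocks that attend to a textual context (`t_context`)
    and a visual context (`v_context`). Cross-attention maps from selected
    layers can be cached and visualised via `save_attn_map`.
    """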
def __init__(
self,
in_channels,
ctrl_channels,
model_channels,
out_channels,
num_res_blocks,
attention_resolutions,
dropout=0,
channel_mult=(1, 2, 4, 8),
save_attn_type=None,
save_attn_layers=[],
conv_resample=True,
dims=2,
use_label=None,
num_heads=-1,
num_head_channels=-1,
num_heads_upsample=-1,
use_scale_shift_norm=False,
resblock_updown=False,
transformer_depth=1,
t_context_dim=None,
v_context_dim=None,
num_attention_blocks=None,
use_linear_in_transformer=False,
adm_in_channels=None,
transformer_depth_middle=None
):
super().__init__()
if num_heads_upsample == -1:
num_heads_upsample = num_heads
if num_heads == -1:
assert (
num_head_channels != -1
), "Either num_heads or num_head_channels has to be set"
if num_head_channels == -1:
assert (
num_heads != -1
), "Either num_heads or num_head_channels has to be set"
self.in_channels = in_channels
self.ctrl_channels = ctrl_channels
self.model_channels = model_channels
self.out_channels = out_channels
transformer_depth = len(channel_mult) * [transformer_depth]
transformer_depth_middle = default(transformer_depth_middle, transformer_depth[-1])
self.num_res_blocks = len(channel_mult) * [num_res_blocks]
self.attention_resolutions = attention_resolutions
self.dropout = dropout
self.channel_mult = channel_mult
self.conv_resample = conv_resample
self.use_label = use_label
self.num_heads = num_heads
self.num_head_channels = num_head_channels
self.num_heads_upsample = num_heads_upsample
time_embed_dim = model_channels * 4
self.time_embed = nn.Sequential(
linear(model_channels, time_embed_dim),
nn.SiLU(),
linear(time_embed_dim, time_embed_dim),
)
if self.use_label is not None:
self.label_emb = nn.Sequential(
nn.Sequential(
linear(adm_in_channels, time_embed_dim),
nn.SiLU(),
linear(time_embed_dim, time_embed_dim),
)
)
self.input_blocks = nn.ModuleList(
[
TimestepEmbedSequential(
conv_nd(dims, in_channels, model_channels, 3, padding=1)
)
]
)
if self.ctrl_channels > 0:
self.ctrl_block = TimestepEmbedSequential(
conv_nd(dims, ctrl_channels, 16, 3, padding=1),
nn.SiLU(),
conv_nd(dims, 16, 16, 3, padding=1),
nn.SiLU(),
conv_nd(dims, 16, 32, 3, padding=1),
nn.SiLU(),
conv_nd(dims, 32, 32, 3, padding=1),
nn.SiLU(),
conv_nd(dims, 32, 96, 3, padding=1),
nn.SiLU(),
conv_nd(dims, 96, 96, 3, padding=1),
nn.SiLU(),
conv_nd(dims, 96, 256, 3, padding=1),
nn.SiLU(),
zero_module(conv_nd(dims, 256, model_channels, 3, padding=1))
)
self._feature_size = model_channels
input_block_chans = [model_channels]
ch = model_channels
ds = 1
for level, mult in enumerate(channel_mult):
for nr in range(self.num_res_blocks[level]):
layers = [
ResBlock(
ch,
time_embed_dim,
dropout,
out_channels=mult * model_channels,
dims=dims,
use_scale_shift_norm=use_scale_shift_norm
)
]
ch = mult * model_channels
if ds in attention_resolutions:
if num_head_channels == -1:
dim_head = ch // num_heads
else:
num_heads = ch // num_head_channels
dim_head = num_head_channels
if (
not exists(num_attention_blocks)
or nr < num_attention_blocks[level]
):
layers.append(
SpatialTransformer(
ch,
num_heads,
dim_head,
depth=transformer_depth[level],
t_context_dim=t_context_dim,
v_context_dim=v_context_dim,
use_linear=use_linear_in_transformer
)
)
self.input_blocks.append(TimestepEmbedSequential(*layers))
self._feature_size += ch
input_block_chans.append(ch)
if level != len(channel_mult) - 1:
out_ch = ch
self.input_blocks.append(
TimestepEmbedSequential(
ResBlock(
ch,
time_embed_dim,
dropout,
out_channels=out_ch,
dims=dims,
use_scale_shift_norm=use_scale_shift_norm,
down=True
)
if resblock_updown
else Downsample(
ch, conv_resample, dims=dims, out_channels=out_ch
)
)
)
ch = out_ch
input_block_chans.append(ch)
ds *= 2
self._feature_size += ch
if num_head_channels == -1:
dim_head = ch // num_heads
else:
num_heads = ch // num_head_channels
dim_head = num_head_channels
self.middle_block = TimestepEmbedSequential(
ResBlock(
ch,
time_embed_dim,
dropout,
dims=dims,
use_scale_shift_norm=use_scale_shift_norm
),
            SpatialTransformer(  # always uses self-attention
ch,
num_heads,
dim_head,
depth=transformer_depth_middle,
t_context_dim=t_context_dim,
v_context_dim=v_context_dim,
use_linear=use_linear_in_transformer
),
ResBlock(
ch,
time_embed_dim,
dropout,
dims=dims,
use_scale_shift_norm=use_scale_shift_norm
)
)
self._feature_size += ch
self.output_blocks = nn.ModuleList([])
for level, mult in list(enumerate(channel_mult))[::-1]:
for i in range(self.num_res_blocks[level] + 1):
ich = input_block_chans.pop()
layers = [
ResBlock(
ch + ich,
time_embed_dim,
dropout,
out_channels=model_channels * mult,
dims=dims,
use_scale_shift_norm=use_scale_shift_norm
)
]
ch = model_channels * mult
if ds in attention_resolutions:
if num_head_channels == -1:
dim_head = ch // num_heads
else:
num_heads = ch // num_head_channels
dim_head = num_head_channels
if (
not exists(num_attention_blocks)
or i < num_attention_blocks[level]
):
layers.append(
SpatialTransformer(
ch,
num_heads,
dim_head,
depth=transformer_depth[level],
t_context_dim=t_context_dim,
v_context_dim=v_context_dim,
use_linear=use_linear_in_transformer
)
)
if level and i == self.num_res_blocks[level]:
out_ch = ch
layers.append(
ResBlock(
ch,
time_embed_dim,
dropout,
out_channels=out_ch,
dims=dims,
use_scale_shift_norm=use_scale_shift_norm,
up=True
)
if resblock_updown
else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
)
ds //= 2
self.output_blocks.append(TimestepEmbedSequential(*layers))
self._feature_size += ch
self.out = nn.Sequential(
normalization(ch),
nn.SiLU(),
zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1))
)
        # cache attention maps from the selected attention modules
        self.attn_type = save_attn_type if save_attn_type is not None else []
        self.attn_layers = save_attn_layers
        self.attn_map_cache = []
        for name, module in self.named_modules():
            if any(name.endswith(attn_type) for attn_type in self.attn_type):
                item = {"name": name, "heads": module.heads, "size": None, "attn_map": None}
                self.attn_map_cache.append(item)
                module.attn_map_cache = item
def clear_attn_map(self):
for item in self.attn_map_cache:
if item["attn_map"] is not None:
del item["attn_map"]
item["attn_map"] = None
    def save_attn_map(self, attn_type="t_attn", save_name="temp", tokens=""):
        attn_maps = []
        heads = None
        for item in self.attn_map_cache:
            name = item["name"]
            if any(name.startswith(block) for block in self.attn_layers) and name.endswith(attn_type):
                heads = item["heads"]
                attn_maps.append(item["attn_map"].detach().cpu())
        assert attn_maps, f"no cached attention maps match layers {self.attn_layers} with type '{attn_type}'"
        attn_map = th.stack(attn_maps, dim=0)
        attn_map = th.mean(attn_map, dim=0)
        # attn_map: (bh, n, l), where bh = batch_size * heads,
        # n = number of spatial positions (h * w) and l = token length
        bh, n, l = attn_map.shape
        attn_map = attn_map.reshape((-1, heads, n, l)).mean(dim=1)  # average over heads
        b = attn_map.shape[0]
        h = w = int(n**0.5)
        attn_map = attn_map.permute(0, 2, 1).reshape((b, l, h, w)).numpy()
attn_map_i = attn_map[-1]
l = attn_map_i.shape[0]
fig = plt.figure(figsize=(12, 8), dpi=300)
for j in range(12):
if j >= l: break
ax = fig.add_subplot(3, 4, j+1)
sns.heatmap(attn_map_i[j], square=True, xticklabels=False, yticklabels=False)
if j < len(tokens):
ax.set_title(tokens[j])
fig.savefig(f"temp/attn_map/attn_map_{save_name}.png")
plt.close()
return attn_map_i
def forward(self, x, timesteps=None, t_context=None, v_context=None, y=None, **kwargs):
assert (y is not None) == (
self.use_label is not None
), "must specify y if and only if the model is class-conditional"
self.clear_attn_map()
hs = []
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
emb = self.time_embed(t_emb)
if self.use_label is not None:
assert y.shape[0] == x.shape[0]
emb = emb + self.label_emb(y)
h = x
if self.ctrl_channels > 0:
in_h, add_h = th.split(h, [self.in_channels, self.ctrl_channels], dim=1)
for i, module in enumerate(self.input_blocks):
if self.ctrl_channels > 0 and i == 0:
h = module(in_h, emb, t_context, v_context) + self.ctrl_block(add_h, emb, t_context, v_context)
else:
h = module(h, emb, t_context, v_context)
hs.append(h)
h = self.middle_block(h, emb, t_context, v_context)
for i, module in enumerate(self.output_blocks):
h = th.cat([h, hs.pop()], dim=1)
h = module(h, emb, t_context, v_context)
h = h.type(x.dtype)
return self.out(h)
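

# Input-splitting sketch (illustrative, hypothetical shapes): when
# `ctrl_channels > 0`, `forward` expects `x` to carry
# `in_channels + ctrl_channels` channels. The first `in_channels` go through
# the normal input stem and the remaining channels through the zero-initialised
# `ctrl_block`; the two feature maps are summed before entering the UNet proper.
def _example_split_control_input():
    in_channels, ctrl_channels = 4, 1
    x = th.randn(2, in_channels + ctrl_channels, 32, 32)
    in_h, add_h = th.split(x, [in_channels, ctrl_channels], dim=1)
    return in_h.shape, add_h.shape  # (2, 4, 32, 32), (2, 1, 32, 32)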