# Repository web-viewer residue (not part of the module):
# uploaded by YuxueYang, commit "Upload demo" (2a59fa8), file size 30.5 kB.
# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py
from dataclasses import dataclass
import os
from os import PathLike
from typing import List, Mapping, Optional, Tuple, Union
from functools import partial
from abc import abstractmethod
from collections import OrderedDict
import torch
import torch.nn as nn
import torch.utils.checkpoint
import torch.nn.functional as F
from einops import rearrange, repeat
from diffusers import __version__
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models import ModelMixin
from diffusers.utils import (
SAFETENSORS_WEIGHTS_NAME,
WEIGHTS_NAME,
BaseOutput,
logging,
_get_model_file,
_add_variant
)
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
from diffusers.models.model_loading_utils import load_state_dict
from ..common import checkpoint
from ..basics import avg_pool_nd, conv_nd, zero_module
from ..modules.attention import SpatialTransformer, TemporalTransformer, CrossAttention
from omegaconf import ListConfig, DictConfig, OmegaConf
# Module-level logger, obtained from the `logging` helper imported from diffusers.utils.
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class TimestepBlock(nn.Module):
    """
    Base class for modules whose forward() receives the timestep embeddings
    as its second positional argument.
    """

    @abstractmethod
    def forward(self, x, emb):
        """
        Transform `x` conditioned on the timestep embeddings `emb`.
        Subclasses are expected to provide the implementation.
        """
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    """
    Sequential container that hands each child the extra inputs it accepts:
    timestep embeddings for TimestepBlock children, context for the
    spatial/temporal transformers, and nothing extra for any other layer.
    """

    def forward(self, x, emb, context=None, batch_size=None, **kwargs):
        for child in self:
            if isinstance(child, TimestepBlock):
                x = child(x, emb, batch_size=batch_size)
                continue
            if isinstance(child, SpatialTransformer):
                # Only the spatial transformer receives the extra keyword arguments.
                x = child(x, context, **kwargs)
                continue
            if isinstance(child, TemporalTransformer):
                # Temporal layers run on (b c f h w); fold frames back into
                # the batch axis once the layer is done.
                video = rearrange(x, '(b f) c h w -> b c f h w', b=batch_size)
                video = child(video, context)
                x = rearrange(video, 'b c f h w -> (b f) c h w')
                continue
            x = child(x)
        return x
class Downsample(nn.Module):
    """
    Downsampling layer implemented either as a strided convolution or as
    average pooling.
    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
                 downsampling occurs in the inner-two dimensions.
    :param out_channels: output channels; falls back to `channels`.
    :param padding: padding passed to the convolution when `use_conv` is set.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        # For 3D signals the first of the three trailing axes keeps its size.
        stride = (1, 2, 2) if dims == 3 else 2
        if not use_conv:
            # Pooling cannot change the channel count.
            assert self.channels == self.out_channels
            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
        else:
            self.op = conv_nd(
                dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
            )

    def forward(self, x):
        assert x.shape[1] == self.channels
        return self.op(x)
class Upsample(nn.Module):
    """
    Nearest-neighbour 2x upsampling layer, optionally followed by a convolution.
    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
                 upsampling occurs in the inner-two dimensions.
    :param out_channels: output channels of the convolution; falls back to `channels`.
    :param padding: padding passed to the convolution when `use_conv` is set.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        if use_conv:
            self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)

    def forward(self, x):
        assert x.shape[1] == self.channels
        if self.dims != 3:
            upsampled = F.interpolate(x, scale_factor=2, mode='nearest')
        else:
            # Keep axis 2 untouched and double the last two axes only.
            target_size = (x.shape[2], x.shape[3] * 2, x.shape[4] * 2)
            upsampled = F.interpolate(x, target_size, mode='nearest')
        return self.conv(upsampled) if self.use_conv else upsampled
class ResBlock(TimestepBlock):
    """
    A residual block that can optionally change the number of channels.
    :param channels: the number of input channels.
    :param emb_channels: the number of timestep embedding channels.
    :param dropout: the rate of dropout.
    :param out_channels: if specified, the number of out channels.
    :param use_scale_shift_norm: if True, the embedding is projected to a
        (scale, shift) pair applied right after the output normalization
        instead of being added to the features.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param use_checkpoint: flag forwarded to the `checkpoint` helper.
    :param use_conv: if True and out_channels is specified, use a spatial
        convolution instead of a smaller 1x1 convolution to change the
        channels in the skip connection.
    :param up: if True, use this block for upsampling.
    :param down: if True, use this block for downsampling.
    :param use_temporal_conv: if True, a TemporalConvBlock is applied to the
        block output whenever forward() is given a batch size.
    :param tempspatial_aware: forwarded to TemporalConvBlock as `spatial_aware`.
    """

    def __init__(
        self,
        channels,
        emb_channels,
        dropout,
        out_channels=None,
        use_scale_shift_norm=False,
        dims=2,
        use_checkpoint=False,
        use_conv=False,
        up=False,
        down=False,
        use_temporal_conv=False,
        tempspatial_aware=False
    ):
        super().__init__()
        self.channels = channels
        self.emb_channels = emb_channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_checkpoint = use_checkpoint
        self.use_scale_shift_norm = use_scale_shift_norm
        self.use_temporal_conv = use_temporal_conv

        self.in_layers = nn.Sequential(
            nn.GroupNorm(32, channels),
            nn.SiLU(),
            conv_nd(dims, channels, self.out_channels, 3, padding=1),
        )

        self.updown = up or down
        if up:
            self.h_upd = Upsample(channels, False, dims)
            self.x_upd = Upsample(channels, False, dims)
        elif down:
            self.h_upd = Downsample(channels, False, dims)
            self.x_upd = Downsample(channels, False, dims)
        else:
            self.h_upd = self.x_upd = nn.Identity()

        self.emb_layers = nn.Sequential(
            nn.SiLU(),
            nn.Linear(
                emb_channels,
                2 * self.out_channels if use_scale_shift_norm else self.out_channels,
            ),
        )
        self.out_layers = nn.Sequential(
            nn.GroupNorm(32, self.out_channels),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            # Built through conv_nd so this conv follows `dims` like every
            # other convolution in the block (it used to be a fixed nn.Conv2d).
            zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)),
        )

        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        elif use_conv:
            self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1)
        else:
            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)

        if self.use_temporal_conv:
            # NOTE: the misspelt attribute name is kept on purpose; it is part
            # of the state-dict keys, so renaming it would break checkpoints.
            # The temporal block always uses dropout=0.1, independent of `dropout`.
            self.temopral_conv = TemporalConvBlock(
                self.out_channels,
                self.out_channels,
                dropout=0.1,
                spatial_aware=tempspatial_aware
            )

    def forward(self, x, emb, batch_size=None):
        """
        Apply the block to a Tensor, conditioned on a timestep embedding.
        :param x: an [N x C x ...] Tensor of features.
        :param emb: an [N x emb_channels] Tensor of timestep embeddings.
        :param batch_size: optional batch size; when truthy it is bound to
            `_forward` so the temporal convolution can regroup the frames.
        :return: an [N x C x ...] Tensor of outputs.
        """
        # The checkpoint helper only forwards (x, emb), so batch_size is bound up front.
        fn = partial(self._forward, batch_size=batch_size) if batch_size else self._forward
        return checkpoint(fn, (x, emb), self.parameters(), self.use_checkpoint)

    def _forward(self, x, emb, batch_size=None):
        if self.updown:
            # Resample between the norm/activation and the convolution, and
            # resample the skip input in the same way.
            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
            h = in_rest(x)
            h = self.h_upd(h)
            x = self.x_upd(x)
            h = in_conv(h)
        else:
            h = self.in_layers(x)

        emb_out = self.emb_layers(emb).type(h.dtype)
        # Append trailing singleton axes until emb_out has as many axes as h.
        while len(emb_out.shape) < len(h.shape):
            emb_out = emb_out[..., None]

        if self.use_scale_shift_norm:
            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
            scale, shift = torch.chunk(emb_out, 2, dim=1)
            h = out_norm(h) * (1 + scale) + shift
            h = out_rest(h)
        else:
            h = h + emb_out
            h = self.out_layers(h)
        h = self.skip_connection(x) + h

        if self.use_temporal_conv and batch_size:
            # Regroup frames per sample for the temporal convolution.
            h = rearrange(h, '(b t) c h w -> b c t h w', b=batch_size)
            h = self.temopral_conv(h)
            h = rearrange(h, 'b c t h w -> (b t) c h w')
        return h
class TemporalConvBlock(nn.Module):
    """
    Residual stack of four 3D convolutions over inputs laid out as
    (batch, channels, frames, height, width).
    Adapted from modelscope: https://github.com/modelscope/modelscope/blob/master/modelscope/models/multi_modal/video_synthesis/unet_sd.py
    :param in_channels: channels of the input and of the output.
    :param out_channels: width between conv1 and conv2; defaults to `in_channels`.
    :param dropout: dropout rate used in conv2, conv3 and conv4.
    :param spatial_aware: if True, the kernels also span one spatial axis
        ((3, 3, 1) alternating with (3, 1, 3)) instead of being purely
        temporal (3, 1, 1).
    """

    def __init__(self, in_channels, out_channels=None, dropout=0.0, spatial_aware=False):
        super().__init__()
        if out_channels is None:
            out_channels = in_channels
        self.in_channels = in_channels
        self.out_channels = out_channels
        th_kernel_shape = (3, 1, 1) if not spatial_aware else (3, 3, 1)
        th_padding_shape = (1, 0, 0) if not spatial_aware else (1, 1, 0)
        tw_kernel_shape = (3, 1, 1) if not spatial_aware else (3, 1, 3)
        tw_padding_shape = (1, 0, 0) if not spatial_aware else (1, 0, 1)

        # conv layers
        self.conv1 = nn.Sequential(
            nn.GroupNorm(32, in_channels), nn.SiLU(),
            nn.Conv3d(in_channels, out_channels, th_kernel_shape, padding=th_padding_shape))
        self.conv2 = nn.Sequential(
            nn.GroupNorm(32, out_channels), nn.SiLU(), nn.Dropout(dropout),
            nn.Conv3d(out_channels, in_channels, tw_kernel_shape, padding=tw_padding_shape))
        # conv2 already maps back to in_channels, so conv3 and conv4 must
        # normalise and convolve in_channels. This is identical to the previous
        # layout when in_channels == out_channels and fixes the channel
        # mismatch that made forward() fail otherwise.
        self.conv3 = nn.Sequential(
            nn.GroupNorm(32, in_channels), nn.SiLU(), nn.Dropout(dropout),
            nn.Conv3d(in_channels, in_channels, th_kernel_shape, padding=th_padding_shape))
        self.conv4 = nn.Sequential(
            nn.GroupNorm(32, in_channels), nn.SiLU(), nn.Dropout(dropout),
            nn.Conv3d(in_channels, in_channels, tw_kernel_shape, padding=tw_padding_shape))

        # zero out the last layer params, so the conv block starts as an identity
        nn.init.zeros_(self.conv4[-1].weight)
        nn.init.zeros_(self.conv4[-1].bias)

    def forward(self, x):
        """
        :param x: a Tensor with `in_channels` channels on axis 1.
        :return: `x` plus the output of the four stacked convolutions.
        """
        identity = x
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        return identity + x
@dataclass
class UNetModelOutput(BaseOutput):
    """
    Output container returned by UNetModel.forward().
    :param sample: the tensor produced by the model.
    """

    sample: torch.FloatTensor
class UNetModel(ModelMixin, ConfigMixin):
    """
    UNet over video tensors laid out as (b, c, t, h, w), built from residual
    blocks, spatial transformers and (optionally) temporal transformers.

    The encoder (`input_blocks`), `middle_block` and decoder (`output_blocks`)
    are connected by skip connections. With `masked_layer_fusion`, extra
    cross-attention modules are created that fuse the `controls` features
    passed to forward() into the skip connections of selected decoder blocks.

    :param in_channels: channels of the input tensor.
    :param model_channels: base channel count; the timestep embedding width
        is 4 * model_channels.
    :param out_channels: channels of the output tensor.
    :param num_res_blocks: residual blocks per resolution level.
    :param attention_resolutions: downsample rates at which attention layers
        are inserted.
    :param dropout: dropout rate of the residual blocks.
    :param channel_mult: per-level multiplier applied to `model_channels`.
    :param conv_resample: passed as `use_conv` to Downsample/Upsample.
    :param dims: dimensionality passed to the convolution helpers.
    :param context_dim: context width passed to the transformers.
    :param use_scale_shift_norm: passed to every ResBlock.
    :param resblock_updown: resample with ResBlocks instead of
        Downsample/Upsample layers.
    :param num_heads: attention heads (used when `num_head_channels` is -1).
    :param num_head_channels: channels per head (takes precedence when set).
    :param transformer_depth: `depth` of every transformer.
    :param use_linear: passed to the transformers.
    :param use_checkpoint: passed to residual blocks and transformers.
    :param temporal_conv: passed as `use_temporal_conv` to the ResBlocks.
    :param tempspatial_aware: passed to the ResBlocks.
    :param temporal_attention: add a TemporalTransformer after every
        SpatialTransformer.
    :param use_relative_position: passed to the temporal transformers.
    :param use_causal_attention: passed to the temporal transformers.
    :param temporal_length: passed to the transformers.
    :param addition_attention: add a temporal transformer right after the
        first input block.
    :param temporal_selfatt_only: `only_self_att` of that additional
        transformer only (see the note in __init__).
    :param image_cross_attention: passed to the spatial transformers.
    :param image_cross_attention_scale_learnable: passed to the spatial
        transformers.
    :param masked_layer_fusion: create the layer-fusion modules.
    :param default_fps: fps value used when forward() gets `fps=None`.
    :param fps_condition: add an fps embedding to the timestep embedding.
    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        in_channels,
        model_channels,
        out_channels,
        num_res_blocks,
        attention_resolutions,
        dropout=0.0,
        channel_mult=(1, 2, 4, 8),
        conv_resample=True,
        dims=2,
        context_dim=None,
        use_scale_shift_norm=False,
        resblock_updown=False,
        num_heads=-1,
        num_head_channels=-1,
        transformer_depth=1,
        use_linear=False,
        use_checkpoint=False,
        temporal_conv=False,
        tempspatial_aware=False,
        temporal_attention=True,
        use_relative_position=True,
        use_causal_attention=False,
        temporal_length=None,
        addition_attention=False,
        temporal_selfatt_only=True,
        image_cross_attention=False,
        image_cross_attention_scale_learnable=False,
        masked_layer_fusion=False,
        default_fps=4,
        fps_condition=False,
    ):
        super().__init__()
        if num_heads == -1:
            assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
        if num_head_channels == -1:
            assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'

        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        self.num_res_blocks = num_res_blocks
        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.temporal_attention = temporal_attention
        time_embed_dim = model_channels * 4
        self.use_checkpoint = use_checkpoint
        # NOTE(review): the temporal transformers inside the UNet always use
        # self-attention only; the `temporal_selfatt_only` argument is applied
        # to `init_attn` alone. Kept as-is because changing it would alter the
        # module layout.
        temporal_self_att_only = True
        self.addition_attention = addition_attention
        self.temporal_length = temporal_length
        self.image_cross_attention = image_cross_attention
        self.image_cross_attention_scale_learnable = image_cross_attention_scale_learnable
        self.default_fps = default_fps
        self.fps_condition = fps_condition

        # ---- local builders shared by the encoder, middle and decoder ----
        def attention_shape(width):
            # (heads, dim_head) for an attention layer of `width` channels.
            if num_head_channels == -1:
                return num_heads, width // num_heads
            return width // num_head_channels, num_head_channels

        def res_block(in_ch, out_ch=None):
            # Residual block carrying the temporal-convolution settings.
            return ResBlock(
                in_ch, time_embed_dim, dropout,
                out_channels=out_ch, dims=dims, use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm, tempspatial_aware=tempspatial_aware,
                use_temporal_conv=temporal_conv
            )

        def resample_layer(width, up):
            # Up/down-sampling layer that keeps the channel count.
            if resblock_updown:
                return ResBlock(
                    width, time_embed_dim, dropout,
                    out_channels=width, dims=dims, use_checkpoint=use_checkpoint,
                    use_scale_shift_norm=use_scale_shift_norm,
                    up=up, down=not up
                )
            sampler_cls = Upsample if up else Downsample
            return sampler_cls(width, conv_resample, dims=dims, out_channels=width)

        def spatial_transformer(width, heads, dim_head):
            return SpatialTransformer(
                width, heads, dim_head,
                depth=transformer_depth, context_dim=context_dim, use_linear=use_linear,
                use_checkpoint=use_checkpoint, disable_self_attn=False,
                video_length=temporal_length, image_cross_attention=self.image_cross_attention,
                image_cross_attention_scale_learnable=self.image_cross_attention_scale_learnable,
            )

        def temporal_transformer(width, heads, dim_head):
            return TemporalTransformer(
                width, heads, dim_head,
                depth=transformer_depth, context_dim=context_dim, use_linear=use_linear,
                use_checkpoint=use_checkpoint, only_self_att=temporal_self_att_only,
                causal_attention=use_causal_attention, relative_position=use_relative_position,
                temporal_length=temporal_length
            )

        ## Time embedding blocks
        self.time_proj = Timesteps(model_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.time_embed = TimestepEmbedding(model_channels, time_embed_dim)
        if fps_condition:
            self.fps_embedding = TimestepEmbedding(model_channels, time_embed_dim)
            # Zero-initialised output layer: the fps embedding starts at zero.
            nn.init.zeros_(self.fps_embedding.linear_2.weight)
            nn.init.zeros_(self.fps_embedding.linear_2.bias)

        ## Input Block
        self.input_blocks = nn.ModuleList(
            [
                TimestepEmbedSequential(conv_nd(dims, in_channels, model_channels, 3, padding=1))
            ]
        )
        if self.addition_attention:
            self.init_attn = TimestepEmbedSequential(
                TemporalTransformer(
                    model_channels,
                    n_heads=8,
                    d_head=num_head_channels,
                    depth=transformer_depth,
                    context_dim=context_dim,
                    use_checkpoint=use_checkpoint, only_self_att=temporal_selfatt_only,
                    causal_attention=False, relative_position=use_relative_position,
                    temporal_length=temporal_length
                )
            )

        input_block_chans = [model_channels]
        ch = model_channels
        ds = 1
        for level, mult in enumerate(channel_mult):
            for _ in range(num_res_blocks):
                layers = [res_block(ch, mult * model_channels)]
                ch = mult * model_channels
                if ds in attention_resolutions:
                    heads, dim_head = attention_shape(ch)
                    layers.append(spatial_transformer(ch, heads, dim_head))
                    if self.temporal_attention:
                        layers.append(temporal_transformer(ch, heads, dim_head))
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                input_block_chans.append(ch)
            if level != len(channel_mult) - 1:
                self.input_blocks.append(TimestepEmbedSequential(resample_layer(ch, up=False)))
                input_block_chans.append(ch)
                ds *= 2

        ## Middle Block
        heads, dim_head = attention_shape(ch)
        layers = [
            res_block(ch),
            spatial_transformer(ch, heads, dim_head),
        ]
        if self.temporal_attention:
            layers.append(temporal_transformer(ch, heads, dim_head))
        layers.append(res_block(ch))
        self.middle_block = TimestepEmbedSequential(*layers)

        ## Output Block
        self.output_blocks = nn.ModuleList([])
        self.masked_layer_fusion = masked_layer_fusion
        if self.masked_layer_fusion:
            self.masked_layer_fusion_norm_list = nn.ModuleList([])
            self.masked_layer_fusion_attn_list = nn.ModuleList([])
            self.masked_layer_fusion_out_list = nn.ModuleList([])
        # Indices of the output blocks that own a fusion module. Always
        # defined so forward() can test membership without fusion enabled.
        self.layer_feature_block_indices = []
        for level, mult in list(enumerate(channel_mult))[::-1]:
            for i in range(num_res_blocks + 1):
                input_channel = ch
                ich = input_block_chans.pop()
                layers = [res_block(ch + ich, mult * model_channels)]
                ch = model_channels * mult
                if ds in attention_resolutions:
                    heads, dim_head = attention_shape(ch)
                    if self.masked_layer_fusion and i < num_res_blocks:
                        self.masked_layer_fusion_norm_list.append(nn.LayerNorm(input_channel))
                        self.masked_layer_fusion_attn_list.append(
                            CrossAttention(
                                query_dim=input_channel,
                                context_dim=ch,
                                dim_head=dim_head,
                                heads=heads,
                                use_xformers=False,
                            )
                        )
                        self.masked_layer_fusion_out_list.append(
                            zero_module(conv_nd(dims, input_channel, ch, 3, padding=1))
                        )
                        self.layer_feature_block_indices.append(len(self.output_blocks))
                    layers.append(spatial_transformer(ch, heads, dim_head))
                    if self.temporal_attention:
                        layers.append(temporal_transformer(ch, heads, dim_head))
                if level and i == num_res_blocks:
                    layers.append(resample_layer(ch, up=True))
                    ds //= 2
                self.output_blocks.append(TimestepEmbedSequential(*layers))

        # NOTE(review): the norm uses `ch` while the conv uses `model_channels`;
        # the two agree only when channel_mult[0] == 1 — confirm before
        # using other multipliers.
        self.out = nn.Sequential(
            nn.GroupNorm(32, ch),
            nn.SiLU(),
            zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
        )

    def forward(self, x, timesteps, context_text, context_img=None, controls=None, layer_validity=None, fps=None, **kwargs):
        """
        Run the UNet on a video tensor.
        :param x: input Tensor laid out as (b, c, t, h, w).
        :param timesteps: values fed to the timestep projection.
        :param context_text: text context; repeated `t` times along the batch axis.
        :param context_img: optional image context laid out as 'b (t l) c';
            concatenated to the text context along axis 1.
        :param controls: optional list of layer features laid out as
            'b n t c h w'; consumed from the end, one entry per fusion block.
            The caller's list is left untouched.
        :param layer_validity: 'b n' mask passed to the fusion attention;
            needed whenever `controls` are fused.
        :param fps: optional fps values; `default_fps` is used when None and
            `fps_condition` is enabled.
        :param kwargs: accepted for call compatibility; not used here.
        :return: UNetModelOutput whose `sample` is laid out as (b, c, t, h, w).
        """
        b, _, t, _, _ = x.shape
        t_emb = self.time_proj(timesteps).type(x.dtype)
        emb = self.time_embed(t_emb)

        ## repeat t times for context [(b t) 77 768] & time embedding
        ## check if we use per-frame image conditioning
        if context_img is not None:  ## decompose context into text and image
            context_text = context_text.repeat_interleave(repeats=t, dim=0)
            context_img = rearrange(context_img, 'b (t l) c -> (b t) l c', t=t)
            context = torch.cat([context_text, context_img], dim=1)
        else:
            context = context_text.repeat_interleave(repeats=t, dim=0)
        emb = emb.repeat_interleave(repeats=t, dim=0)

        ## always in shape (b t) c h w, except for temporal layer
        x = rearrange(x, 'b c t h w -> (b t) c h w')

        ## combine emb
        if self.fps_condition:
            if fps is None:
                # Fixed: this used to read the non-existent `self.default_fs`.
                fps = torch.tensor(
                    [self.default_fps] * b, dtype=torch.long, device=x.device)
            fps_emb = self.time_proj(fps).type(x.dtype)
            fps_embed = self.fps_embedding(fps_emb)
            fps_embed = fps_embed.repeat_interleave(repeats=t, dim=0)
            emb = emb + fps_embed

        h = x.type(self.dtype)
        hs = []
        for block_idx, module in enumerate(self.input_blocks):
            h = module(h, emb, context=context, batch_size=b)
            if block_idx == 0 and self.addition_attention:
                h = self.init_attn(h, emb, context=context, batch_size=b)
            hs.append(h)
        h = self.middle_block(h, emb, context=context, batch_size=b)

        # Work on a shallow copy so pop() does not drain the caller's list.
        pending_controls = list(controls) if controls is not None else None
        layer_fusion_idx = 0
        for block_idx, module in enumerate(self.output_blocks):
            skip = hs.pop()
            if pending_controls and block_idx in self.layer_feature_block_indices:
                layer_features = pending_controls.pop()
                feature_h, feature_w = layer_features.shape[-2:]
                # One attention problem per (sample, frame, pixel): the frame
                # feature is the single query, the n layer features the keys.
                layer_features = rearrange(layer_features, 'b n t c h w -> (b t h w) n c')
                frame_features = rearrange(h, '(b t) c h w -> (b t) (h w) c', b=b)
                frame_features = self.masked_layer_fusion_norm_list[layer_fusion_idx](frame_features)
                frame_features = rearrange(frame_features, '(b t) (h w) c -> (b t h w) 1 c', b=b, t=t, h=feature_h, w=feature_w)
                fused_features = self.masked_layer_fusion_attn_list[layer_fusion_idx](
                    frame_features,
                    layer_features,
                    mask=repeat(layer_validity, "b n -> (b t h w) 1 n", t=t, h=feature_h, w=feature_w)
                )
                fused_features = rearrange(fused_features, '(b t h w) 1 c -> (b t) c h w', b=b, t=t, h=feature_h, w=feature_w)
                fused_features = self.masked_layer_fusion_out_list[layer_fusion_idx](fused_features)
                skip += fused_features
                layer_fusion_idx += 1
            h = torch.cat([h, skip], dim=1)
            h = module(h, emb, context=context, batch_size=b)
        h = h.type(x.dtype)
        y = self.out(h)

        # reshape back to (b c t h w)
        y = rearrange(y, '(b t) c h w -> b c t h w', b=b)
        return UNetModelOutput(sample=y)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, unet_additional_kwargs=None, **kwargs):
        """
        Build the model from a stored config and load its weights non-strictly.
        :param pretrained_model_name_or_path: location passed to the config
            and weight-file resolvers.
        :param unet_additional_kwargs: optional mapping whose entries override
            the loaded config; OmegaConf containers are converted to plain
            Python containers first.
        :param kwargs: `cache_dir`, `force_download`, `proxies`,
            `local_files_only`, `token`, `revision`, `subfolder`, `variant`
            and `use_safetensors` are consumed here; whatever remains is
            forwarded to `load_config`.
        :return: the model with the state dict loaded using `strict=False`.
        """
        if unet_additional_kwargs is None:
            unet_additional_kwargs = {}

        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", None)
        token = kwargs.pop("token", None)
        revision = kwargs.pop("revision", None)
        subfolder = kwargs.pop("subfolder", None)
        variant = kwargs.pop("variant", None)
        use_safetensors = kwargs.pop("use_safetensors", None)

        # When the caller expressed no preference, try safetensors first and
        # allow falling back to the pickle-based weights file.
        allow_pickle = False
        if use_safetensors is None:
            use_safetensors = True
            allow_pickle = True

        # Load config if we don't provide a configuration
        config_path = pretrained_model_name_or_path
        user_agent = {
            "diffusers": __version__,
            "file_type": "model",
            "framework": "pytorch",
        }

        # load config
        config, unused_kwargs, commit_hash = cls.load_config(
            config_path,
            cache_dir=cache_dir,
            return_unused_kwargs=True,
            return_commit_hash=True,
            force_download=force_download,
            proxies=proxies,
            local_files_only=local_files_only,
            token=token,
            revision=revision,
            subfolder=subfolder,
            user_agent=user_agent,
            **kwargs,
        )
        for key, value in unet_additional_kwargs.items():
            if isinstance(value, (ListConfig, DictConfig)):
                config[key] = OmegaConf.to_container(value, resolve=True)
            else:
                config[key] = value

        # load model
        # Arguments shared by both weight-file lookups.
        fetch_kwargs = dict(
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            local_files_only=local_files_only,
            token=token,
            revision=revision,
            subfolder=subfolder,
            user_agent=user_agent,
            commit_hash=commit_hash,
        )
        model_file = None
        if use_safetensors:
            try:
                model_file = _get_model_file(
                    pretrained_model_name_or_path,
                    weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
                    **fetch_kwargs,
                )
            except IOError as e:
                logger.error(f"An error occurred while trying to fetch {pretrained_model_name_or_path}: {e}")
                if not allow_pickle:
                    raise
                logger.warning(
                    "Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead."
                )
        if model_file is None:
            model_file = _get_model_file(
                pretrained_model_name_or_path,
                weights_name=_add_variant(WEIGHTS_NAME, variant),
                **fetch_kwargs,
            )

        model = cls.from_config(config, **unused_kwargs)
        state_dict = load_state_dict(model_file, variant)
        # Non-strict load: mismatching keys are only counted and reported.
        missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
        print(f"UNetModel loaded from {model_file} with {len(missing_keys)} missing keys and {len(unexpected_keys)} unexpected keys.")
        return model