import math
from dataclasses import dataclass
from enum import Enum
from typing import NamedTuple, Tuple

import torch
from choices import *
from config_base import BaseConfig
from torch import nn
from torch.nn import init

from .blocks import *
from .nn import timestep_embedding
from .unet import *


class LatentNetType(Enum):
    none = 'none'
    # injecting inputs into the hidden layers
    skip = 'skip'


class LatentNetReturn(NamedTuple):
    pred: torch.Tensor = None


@dataclass
class MLPSkipNetConfig(BaseConfig):
    """
    Default MLP configuration for the latent DPM in the paper.
    """
    num_channels: int
    skip_layers: Tuple[int]
    num_hid_channels: int
    num_layers: int
    num_time_emb_channels: int = 64
    activation: Activation = Activation.silu
    use_norm: bool = True
    condition_bias: float = 1
    dropout: float = 0
    last_act: Activation = Activation.none
    num_time_layers: int = 2
    time_last_act: bool = False

    def make_model(self):
        return MLPSkipNet(self)
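
# Note: skip_layers lists the layer indices that receive the input x via
# concatenation, and condition_bias is added to the time-conditioning scale
# (see MLPLNAct.forward below).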


class MLPSkipNet(nn.Module):
    """
    MLP that concatenates the input x to selected hidden layers (skip connections).

    Default latent DPM network in the paper.
    """
    def __init__(self, conf: MLPSkipNetConfig):
        super().__init__()
        self.conf = conf

        # time embedding MLP: maps the sinusoidal timestep embedding
        # (num_time_emb_channels) to the conditioning vector (num_channels)
        layers = []
        for i in range(conf.num_time_layers):
            if i == 0:
                a = conf.num_time_emb_channels
                b = conf.num_channels
            else:
                a = conf.num_channels
                b = conf.num_channels
            layers.append(nn.Linear(a, b))
            if i < conf.num_time_layers - 1 or conf.time_last_act:
                layers.append(conf.activation.get_act())
        self.time_embed = nn.Sequential(*layers)
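
        # Main MLP stack. The first layer maps the latent (num_channels) to the
        # hidden width, the last layer maps back to num_channels, and every
        # index listed in skip_layers gets the original input x concatenated to
        # its input, which is why its in_channels is widened below.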
        self.layers = nn.ModuleList([])
        for i in range(conf.num_layers):
            if i == 0:
                act = conf.activation
                norm = conf.use_norm
                cond = True
                a, b = conf.num_channels, conf.num_hid_channels
                dropout = conf.dropout
            elif i == conf.num_layers - 1:
                act = Activation.none
                norm = False
                cond = False
                a, b = conf.num_hid_channels, conf.num_channels
                dropout = 0
            else:
                act = conf.activation
                norm = conf.use_norm
                cond = True
                a, b = conf.num_hid_channels, conf.num_hid_channels
                dropout = conf.dropout

            if i in conf.skip_layers:
                a += conf.num_channels

            self.layers.append(
                MLPLNAct(
                    a,
                    b,
                    norm=norm,
                    activation=act,
                    cond_channels=conf.num_channels,
                    use_cond=cond,
                    condition_bias=conf.condition_bias,
                    dropout=dropout,
                ))
        self.last_act = conf.last_act.get_act()

    def forward(self, x, t, **kwargs):
        t = timestep_embedding(t, self.conf.num_time_emb_channels)
        cond = self.time_embed(t)
        h = x
        for i in range(len(self.layers)):
            if i in self.conf.skip_layers:
                # injecting input into the hidden layers
                h = torch.cat([h, x], dim=1)
            h = self.layers[i].forward(x=h, cond=cond)
        h = self.last_act(h)
        return LatentNetReturn(h)
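
# Shape sketch (illustrative): for a batch of n latents with num_channels = c,
# x is (n, c) and t is (n,); cond = time_embed(timestep_embedding(t)) is (n, c),
# and the returned pred has shape (n, c).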


class MLPLNAct(nn.Module):
    """
    Linear -> (optional time-conditioned scale) -> LayerNorm -> activation -> dropout.
    """
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        norm: bool,
        use_cond: bool,
        activation: Activation,
        cond_channels: int,
        condition_bias: float = 0,
        dropout: float = 0,
    ):
        super().__init__()
        self.activation = activation
        self.condition_bias = condition_bias
        self.use_cond = use_cond

        self.linear = nn.Linear(in_channels, out_channels)
        self.act = activation.get_act()
        if self.use_cond:
            self.linear_emb = nn.Linear(cond_channels, out_channels)
            self.cond_layers = nn.Sequential(self.act, self.linear_emb)
        if norm:
            self.norm = nn.LayerNorm(out_channels)
        else:
            self.norm = nn.Identity()

        if dropout > 0:
            self.dropout = nn.Dropout(p=dropout)
        else:
            self.dropout = nn.Identity()

        self.init_weights()

    def init_weights(self):
        for module in self.modules():
            if isinstance(module, nn.Linear):
                if self.activation == Activation.relu:
                    init.kaiming_normal_(module.weight,
                                         a=0,
                                         nonlinearity='relu')
                elif self.activation == Activation.lrelu:
                    init.kaiming_normal_(module.weight,
                                         a=0.2,
                                         nonlinearity='leaky_relu')
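                # kaiming_normal_ has no dedicated gain for SiLU, so the ReLU
                # gain is used here as an approximation.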
                elif self.activation == Activation.silu:
                    init.kaiming_normal_(module.weight,
                                         a=0,
                                         nonlinearity='relu')
                else:
                    # leave it as default
                    pass

    def forward(self, x, cond=None):
        x = self.linear(x)
        if self.use_cond:
            # (n, c) or (n, c * 2)
            cond = self.cond_layers(cond)
            cond = (cond, None)

            # scale shift first
            x = x * (self.condition_bias + cond[0])
            if cond[1] is not None:
                x = x + cond[1]
            # then norm
            x = self.norm(x)
        else:
            # no condition
            x = self.norm(x)
        x = self.act(x)
        x = self.dropout(x)
        return x
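

# Minimal smoke test, as a sketch: the shapes and config values below are
# placeholders, not the paper's settings, and running this file directly
# assumes it lives in a package (e.g. `python -m model.latentnet` from the
# repo root) so that the relative imports above resolve.
if __name__ == '__main__':
    conf = MLPSkipNetConfig(num_channels=512,
                            skip_layers=tuple(range(1, 10)),
                            num_hid_channels=2048,
                            num_layers=10)
    net = conf.make_model()
    x = torch.randn(4, conf.num_channels)  # batch of 4 latent codes
    t = torch.randint(0, 1000, (4, ))      # diffusion timesteps
    out = net(x, t)
    assert out.pred.shape == (4, conf.num_channels)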