# tryoffdiff / model.py (rizavelioglu, v2, 8eb415a)
from enum import Enum, unique
from typing import Any
import torch
import torchvision.transforms.v2 as transforms
from diffusers import AutoencoderKL, UNet2DConditionModel, UNet2DModel
from torch import Tensor, nn
from transformers import (
AutoImageProcessor,
AutoModel,
AutoProcessor,
CLIPImageProcessor,
CLIPVisionModel,
SiglipImageProcessor,
SiglipVisionModel,
)
class TryOffDiff(nn.Module):
    """Pretrained Stable Diffusion v1.4 UNet driven by an adapted conditioning sequence.

    The conditioning embedding is refined by a single Transformer encoder layer. Its token
    axis is then linearly resized from 1024 to 77 positions, and the result is
    layer-normalised over the 768-wide feature axis. It is then handed to the UNet as
    ``encoder_hidden_states``.
    """

    def __init__(self):
        super().__init__()
        # Denoiser: the complete pretrained UNet of CompVis/stable-diffusion-v1-4.
        self.unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet")
        # Conditioning adapter. Submodules are created in this order so parameter
        # initialisation and state_dict key order stay stable.
        self.transformer = nn.TransformerEncoderLayer(d_model=768, nhead=8, batch_first=True)
        self.proj = nn.Linear(1024, 77)  # acts on the token axis: 1024 tokens -> 77 tokens
        self.norm = nn.LayerNorm(768)  # normalises each 768-dim token vector

    def forward(self, noisy_latents, t, cond_emb):
        """Return the UNet prediction for ``noisy_latents`` under the adapted conditioning.

        Args:
            noisy_latents: Latent input, forwarded to the UNet unchanged.
            t: Timestep(s), forwarded to the UNet unchanged.
            cond_emb: Conditioning sequence. ``self.transformer`` (d_model=768, batch_first)
                and ``self.norm`` need a 768-wide last axis, and ``self.proj`` needs 1024
                entries on axis 1. In other words, ``(batch, 1024, 768)``.

        Returns:
            The ``.sample`` field of the UNet output.
        """
        refined = self.transformer(cond_emb)
        # Move the token axis last so the Linear layer resizes it, then move it back.
        tokens_last = torch.transpose(refined, 1, 2)
        resized = torch.transpose(self.proj(tokens_last), 1, 2)
        context = self.norm(resized)
        unet_output = self.unet(noisy_latents, t, encoder_hidden_states=context)
        return unet_output.sample
class TryOffDiffv2(nn.Module):
    """Class-conditioned variant built on a UNet with the Stable Diffusion v1.4 layout.

    The UNet is constructed from an explicit config that enables a class embedding
    (``num_class_embeds=3``). It is then initialised from the pretrained
    ``CompVis/stable-diffusion-v1-4`` UNet wherever the parameter names and shapes match.
    Parameters without a compatible pretrained counterpart keep their fresh initialisation.
    Presumably this includes the class embedding, since SD v1.4 is not class-conditioned;
    confirm this if it matters.
    The conditioning adapter (token-axis Linear + LayerNorm) matches ``TryOffDiffv2Single``.
    """

    def __init__(self):
        super().__init__()
        self.unet = UNet2DConditionModel(
            sample_size=64,
            in_channels=4,
            out_channels=4,
            layers_per_block=2,
            block_out_channels=(320, 640, 1280, 1280),
            down_block_types=(
                "CrossAttnDownBlock2D",
                "CrossAttnDownBlock2D",
                "CrossAttnDownBlock2D",
                "DownBlock2D",
            ),
            up_block_types=(
                "UpBlock2D",
                "CrossAttnUpBlock2D",
                "CrossAttnUpBlock2D",
                "CrossAttnUpBlock2D",
            ),
            cross_attention_dim=768,
            class_embed_type=None,
            num_class_embeds=3,
        )
        # Load the pretrained weights into the custom model, skipping incompatible keys.
        # strict=False alone only tolerates missing/unexpected keys. A tensor whose shape
        # differs from ours would still raise RuntimeError, so such entries (and keys we
        # don't have) are filtered out before loading.
        pretrained_state_dict = UNet2DConditionModel.from_pretrained(
            "CompVis/stable-diffusion-v1-4", subfolder="unet"
        ).state_dict()
        own_state_dict = self.unet.state_dict()
        compatible_state_dict = {
            key: tensor
            for key, tensor in pretrained_state_dict.items()
            if key in own_state_dict and own_state_dict[key].shape == tensor.shape
        }
        self.unet.load_state_dict(compatible_state_dict, strict=False)
        self.proj = nn.Linear(1024, 77)  # resizes the token axis: 1024 -> 77 tokens
        self.norm = nn.LayerNorm(768)  # normalises each 768-dim token vector

    def forward(self, noisy_latents, t, cond_emb, class_labels):
        """Return the UNet prediction for ``noisy_latents`` given conditioning and class.

        Args:
            noisy_latents: Latent input for the UNet. Presumably ``(batch, 4, 64, 64)``
                given ``in_channels=4`` and ``sample_size=64``; confirm.
            t: Timestep(s), forwarded to the UNet unchanged.
            cond_emb: Conditioning sequence. ``self.proj`` needs 1024 entries on axis 1
                and ``self.norm`` a 768-wide last axis, i.e. ``(batch, 1024, 768)``.
            class_labels: Class ids for the UNet's class embedding. Presumably integers
                in ``[0, 3)`` given ``num_class_embeds=3``.

        Returns:
            The ``.sample`` field of the UNet output.
        """
        # Move the token axis last so the Linear layer resizes it, then move it back.
        cond_emb = self.proj(cond_emb.transpose(1, 2))
        cond_emb = self.norm(cond_emb.transpose(1, 2))
        return self.unet(noisy_latents, t, encoder_hidden_states=cond_emb, class_labels=class_labels).sample
class TryOffDiffv2Single(nn.Module):
    """Pretrained Stable Diffusion v1.4 UNet with a lightweight conditioning adapter.

    Unlike ``TryOffDiff`` there is no Transformer refinement. The conditioning sequence
    only has its token axis linearly resized from 1024 to 77 positions and is then
    layer-normalised over the 768-wide feature axis.
    """

    def __init__(self):
        super().__init__()
        # Denoiser: the complete pretrained UNet of CompVis/stable-diffusion-v1-4.
        self.unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet")
        self.proj = nn.Linear(1024, 77)  # acts on the token axis: 1024 tokens -> 77 tokens
        self.norm = nn.LayerNorm(768)  # normalises each 768-dim token vector

    def forward(self, noisy_latents, t, cond_emb):
        """Return the UNet prediction for ``noisy_latents`` under the adapted conditioning.

        Args:
            noisy_latents: Latent input, forwarded to the UNet unchanged.
            t: Timestep(s), forwarded to the UNet unchanged.
            cond_emb: Conditioning sequence. ``self.proj`` needs 1024 entries on axis 1
                and ``self.norm`` a 768-wide last axis, i.e. ``(batch, 1024, 768)``.

        Returns:
            The ``.sample`` field of the UNet output.
        """
        # Move the token axis last so the Linear layer resizes it, then move it back.
        tokens_last = torch.transpose(cond_emb, 1, 2)
        resized = torch.transpose(self.proj(tokens_last), 1, 2)
        context = self.norm(resized)
        unet_output = self.unet(noisy_latents, t, encoder_hidden_states=context)
        return unet_output.sample
@unique
class ModelName(Enum):
    """Registry of the model classes defined in this module, keyed by member name.

    ``create_model`` looks members up by name (``ModelName[name]``) and instantiates
    ``.value``. ``@unique`` rejects two names that alias the same class.
    """

    # Each value is the class object itself, not an instance.
    TryOffDiff = TryOffDiff
    TryOffDiffv2 = TryOffDiffv2
    TryOffDiffv2Single = TryOffDiffv2Single
def create_model(model_name: str, **kwargs: Any) -> nn.Module:
    """Instantiate the model class registered under ``model_name`` in ``ModelName``.

    Args:
        model_name: Exact (case-sensitive) name of a ``ModelName`` member,
            e.g. ``"TryOffDiffv2"``.
        **kwargs: Forwarded unchanged to the selected class's constructor.

    Returns:
        A new instance of the selected model class.

    Raises:
        KeyError: If ``model_name`` is not a ``ModelName`` member. The message lists
            the valid names. The exception type is unchanged from a plain enum lookup,
            so existing ``except KeyError`` handlers keep working.
    """
    try:
        model_class = ModelName[model_name].value
    except KeyError:
        valid_names = ", ".join(ModelName.__members__)
        raise KeyError(f"Unknown model name {model_name!r}; expected one of: {valid_names}") from None
    return model_class(**kwargs)