from enum import Enum, unique
from typing import Any

import torch
import torchvision.transforms.v2 as transforms
from diffusers import AutoencoderKL, UNet2DConditionModel, UNet2DModel
from torch import Tensor, nn
from transformers import (
    AutoImageProcessor,
    AutoModel,
    AutoProcessor,
    CLIPImageProcessor,
    CLIPVisionModel,
    SiglipImageProcessor,
    SiglipVisionModel,
)


class TryOffDiff(nn.Module):
    """SD-1.4 UNet conditioned on image embeddings through a small adapter.

    The adapter (transformer encoder layer, linear projection, layer norm) maps
    image-encoder tokens of shape [B, 1024, 768] to [B, 77, 768] so they can
    stand in for the CLIP text embeddings expected by the UNet cross-attention.
    """

    def __init__(self):
        super().__init__()
        self.unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet")
        self.transformer = torch.nn.TransformerEncoderLayer(d_model=768, nhead=8, batch_first=True)
        self.proj = nn.Linear(1024, 77)
        self.norm = nn.LayerNorm(768)

    def forward(self, noisy_latents, t, cond_emb):
        cond_emb = self.transformer(cond_emb)           # [B, 1024, 768]
        cond_emb = self.proj(cond_emb.transpose(1, 2))  # [B, 768, 77]
        cond_emb = self.norm(cond_emb.transpose(1, 2))  # [B, 77, 768]
        return self.unet(noisy_latents, t, encoder_hidden_states=cond_emb).sample
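

# --- Illustrative usage sketch (not part of the original file) --------------
# A dummy forward pass through TryOffDiff, assuming the conditioning comes
# from an image encoder that emits 1024 tokens of dimension 768 (e.g. a SigLIP
# ViT-B/16 at 512x512 input, which yields 32x32 = 1024 patch tokens). All
# tensors below are random placeholders:
#
#   model = TryOffDiff()
#   noisy_latents = torch.randn(1, 4, 64, 64)       # VAE latents of a 512x512 image
#   t = torch.tensor([10])                          # diffusion timestep
#   cond_emb = torch.randn(1, 1024, 768)            # image-encoder tokens
#   noise_pred = model(noisy_latents, t, cond_emb)  # -> shape (1, 4, 64, 64)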


class TryOffDiffv2(nn.Module):
    """SD-1.4-shaped UNet with an additional learned class embedding (3 classes).

    The UNet is built from the SD-1.4 configuration plus `num_class_embeds=3`,
    then initialized from the pretrained SD-1.4 weights with `strict=False`, so
    the class-embedding table (absent from the checkpoint) stays randomly
    initialized.
    """

    def __init__(self):
        super().__init__()
        self.unet = UNet2DConditionModel(
            sample_size=64,
            in_channels=4,
            out_channels=4,
            layers_per_block=2,
            block_out_channels=(320, 640, 1280, 1280),
            down_block_types=(
                "CrossAttnDownBlock2D",
                "CrossAttnDownBlock2D",
                "CrossAttnDownBlock2D",
                "DownBlock2D",
            ),
            up_block_types=(
                "UpBlock2D",
                "CrossAttnUpBlock2D",
                "CrossAttnUpBlock2D",
                "CrossAttnUpBlock2D",
            ),
            cross_attention_dim=768,
            class_embed_type=None,
            num_class_embeds=3,
        )
        # Load the pretrained SD-1.4 weights into the custom UNet; `strict=False`
        # skips the keys that have no counterpart in the checkpoint (the class
        # embedding added by `num_class_embeds=3`).
        pretrained_state_dict = UNet2DConditionModel.from_pretrained(
            "CompVis/stable-diffusion-v1-4", subfolder="unet"
        ).state_dict()
        self.unet.load_state_dict(pretrained_state_dict, strict=False)
        self.proj = nn.Linear(1024, 77)
        self.norm = nn.LayerNorm(768)

    def forward(self, noisy_latents, t, cond_emb, class_labels):
        # Squeeze the image-encoder tokens from [B, 1024, 768] to [B, 77, 768].
        cond_emb = self.proj(cond_emb.transpose(1, 2))
        cond_emb = self.norm(cond_emb.transpose(1, 2))
        return self.unet(noisy_latents, t, encoder_hidden_states=cond_emb, class_labels=class_labels).sample
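

# --- Illustrative usage sketch (not part of the original file) --------------
# TryOffDiffv2 additionally takes one integer class label per sample; with
# `num_class_embeds=3` the labels must lie in {0, 1, 2} (presumably garment
# categories, though this file does not say what the classes mean):
#
#   model = TryOffDiffv2()
#   noisy_latents = torch.randn(1, 4, 64, 64)
#   t = torch.tensor([10])
#   cond_emb = torch.randn(1, 1024, 768)
#   class_labels = torch.tensor([0])                # int64 labels in [0, 3)
#   noise_pred = model(noisy_latents, t, cond_emb, class_labels=class_labels)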


class TryOffDiffv2Single(nn.Module):
    """Variant with the pretrained SD-1.4 UNet and the linear adapter only
    (no class conditioning and no transformer layer); same call signature
    as `TryOffDiff`."""

    def __init__(self):
        super().__init__()
        self.unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet")
        self.proj = nn.Linear(1024, 77)
        self.norm = nn.LayerNorm(768)

    def forward(self, noisy_latents, t, cond_emb):
        cond_emb = self.proj(cond_emb.transpose(1, 2))
        cond_emb = self.norm(cond_emb.transpose(1, 2))
        return self.unet(noisy_latents, t, encoder_hidden_states=cond_emb).sample


@unique
class ModelName(Enum):
    """Registry mapping model names to model classes."""

    TryOffDiff = TryOffDiff
    TryOffDiffv2 = TryOffDiffv2
    TryOffDiffv2Single = TryOffDiffv2Single


def create_model(model_name: str, **kwargs: Any) -> nn.Module:
    """Instantiate a model by name; raises KeyError for unknown names."""
    model_class = ModelName[model_name].value
    return model_class(**kwargs)
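

# --- Illustrative usage sketch (not part of the original file) --------------
# The factory looks up the class via its ModelName member and forwards any
# keyword arguments (none are required by the constructors above):
#
#   model = create_model("TryOffDiffv2Single")
#   assert isinstance(model, nn.Module)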