File size: 3,293 Bytes
8eb415a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
from enum import Enum, unique
from typing import Any

import torch
import torchvision.transforms.v2 as transforms
from diffusers import AutoencoderKL, UNet2DConditionModel, UNet2DModel
from torch import Tensor, nn
from transformers import (
    AutoImageProcessor,
    AutoModel,
    AutoProcessor,
    CLIPImageProcessor,
    CLIPVisionModel,
    SiglipImageProcessor,
    SiglipVisionModel,
)


class TryOffDiff(nn.Module):
    """Pretrained Stable Diffusion v1.4 UNet fed with an adapted conditioning embedding.

    Before reaching the UNet, the conditioning embedding is passed through one
    transformer encoder layer, linearly projected along axis 1 from 1024 to 77
    entries, and layer-normalised over its 768 features.
    """

    def __init__(self):
        super().__init__()
        # Attribute names and construction order are kept fixed: they determine
        # the state-dict keys and the order in which parameters are initialised.
        self.unet = UNet2DConditionModel.from_pretrained(
            "CompVis/stable-diffusion-v1-4",
            subfolder="unet",
        )
        self.transformer = nn.TransformerEncoderLayer(d_model=768, nhead=8, batch_first=True)
        self.proj = nn.Linear(in_features=1024, out_features=77)
        self.norm = nn.LayerNorm(normalized_shape=768)

    def forward(self, noisy_latents, t, cond_emb):
        """Run the UNet on ``noisy_latents`` conditioned on the adapted ``cond_emb``.

        Args:
            noisy_latents: First positional input of the UNet.
            t: Second positional input of the UNet (presumably the diffusion
                timestep -- confirm against the caller).
            cond_emb: Conditioning embedding. The layer sizes imply a
                ``(batch, 1024, 768)`` tensor -- confirm against the caller.

        Returns:
            The ``sample`` attribute of the UNet output.
        """
        refined = self.transformer(cond_emb)
        # nn.Linear acts on the last axis, so axes 1 and 2 are swapped to
        # project the 1024-long axis down to 77, then swapped back.
        compressed = self.proj(refined.transpose(1, 2))
        context = self.norm(compressed.transpose(1, 2))
        unet_output = self.unet(noisy_latents, t, encoder_hidden_states=context)
        return unet_output.sample

class TryOffDiffv2(nn.Module):
    """Class-conditioned UNet initialised from Stable Diffusion v1.4 weights.

    A ``UNet2DConditionModel`` is built from an explicit configuration with
    ``num_class_embeds=3`` and then filled with the pretrained SD v1.4 UNet
    weights. The conditioning embedding is projected along axis 1 from 1024 to
    77 entries and layer-normalised over its 768 features before it is handed
    to the UNet.
    """

    def __init__(self):
        super().__init__()
        unet_config = dict(
            sample_size=64,
            in_channels=4,
            out_channels=4,
            layers_per_block=2,
            block_out_channels=(320, 640, 1280, 1280),
            down_block_types=(
                "CrossAttnDownBlock2D",
                "CrossAttnDownBlock2D",
                "CrossAttnDownBlock2D",
                "DownBlock2D",
            ),
            up_block_types=(
                "UpBlock2D",
                "CrossAttnUpBlock2D",
                "CrossAttnUpBlock2D",
                "CrossAttnUpBlock2D",
            ),
            cross_attention_dim=768,
            # NOTE(review): presumably this pair requests a plain 3-way class
            # embedding -- confirm against the diffusers documentation.
            class_embed_type=None,
            num_class_embeds=3,
        )
        self.unet = UNet2DConditionModel(**unet_config)

        # Copy the pretrained weights into the custom UNet. strict=False
        # tolerates keys that are missing from, or unexpected in, the
        # checkpoint instead of raising on them.
        pretrained_unet = UNet2DConditionModel.from_pretrained(
            "CompVis/stable-diffusion-v1-4",
            subfolder="unet",
        )
        self.unet.load_state_dict(pretrained_unet.state_dict(), strict=False)

        self.proj = nn.Linear(in_features=1024, out_features=77)
        self.norm = nn.LayerNorm(normalized_shape=768)

    def forward(self, noisy_latents, t, cond_emb, class_labels):
        """Run the UNet with both an adapted conditioning embedding and class labels.

        Args:
            noisy_latents: First positional input of the UNet.
            t: Second positional input of the UNet (presumably the diffusion
                timestep -- confirm against the caller).
            cond_emb: Conditioning embedding. The layer sizes imply a
                ``(batch, 1024, 768)`` tensor -- confirm against the caller.
            class_labels: Forwarded to the UNet as ``class_labels``
                (presumably integer labels in ``[0, 3)`` -- confirm).

        Returns:
            The ``sample`` attribute of the UNet output.
        """
        # nn.Linear acts on the last axis, so axes 1 and 2 are swapped to
        # project the 1024-long axis down to 77, then swapped back.
        compressed = self.proj(cond_emb.transpose(1, 2))
        context = self.norm(compressed.transpose(1, 2))
        unet_output = self.unet(
            noisy_latents,
            t,
            encoder_hidden_states=context,
            class_labels=class_labels,
        )
        return unet_output.sample

class TryOffDiffv2Single(nn.Module):
    """Pretrained Stable Diffusion v1.4 UNet with a projected conditioning embedding.

    The conditioning embedding is projected along axis 1 from 1024 to 77
    entries and layer-normalised over its 768 features before it is handed to
    the UNet. No class labels are used.
    """

    def __init__(self):
        super().__init__()
        # Attribute names and construction order are kept fixed: they determine
        # the state-dict keys and the order in which parameters are initialised.
        self.unet = UNet2DConditionModel.from_pretrained(
            "CompVis/stable-diffusion-v1-4",
            subfolder="unet",
        )
        self.proj = nn.Linear(in_features=1024, out_features=77)
        self.norm = nn.LayerNorm(normalized_shape=768)

    def forward(self, noisy_latents, t, cond_emb):
        """Run the UNet on ``noisy_latents`` conditioned on the adapted ``cond_emb``.

        Args:
            noisy_latents: First positional input of the UNet.
            t: Second positional input of the UNet (presumably the diffusion
                timestep -- confirm against the caller).
            cond_emb: Conditioning embedding. The layer sizes imply a
                ``(batch, 1024, 768)`` tensor -- confirm against the caller.

        Returns:
            The ``sample`` attribute of the UNet output.
        """
        # nn.Linear acts on the last axis, so axes 1 and 2 are swapped to
        # project the 1024-long axis down to 77, then swapped back.
        compressed = self.proj(cond_emb.transpose(1, 2))
        context = self.norm(compressed.transpose(1, 2))
        unet_output = self.unet(noisy_latents, t, encoder_hidden_states=context)
        return unet_output.sample

@unique
class ModelName(Enum):
    """Registry mapping model names to the model classes defined in this module.

    Each member's name is the string accepted by ``create_model`` and its value
    is the class to instantiate. ``@unique`` makes the class definition fail if
    two members share the same value.
    """

    TryOffDiff = TryOffDiff
    TryOffDiffv2 = TryOffDiffv2
    TryOffDiffv2Single = TryOffDiffv2Single

def create_model(model_name: str, **kwargs: Any) -> Any:
    """Instantiate the model class registered under ``model_name``.

    Args:
        model_name: Name of a ``ModelName`` member, e.g. ``"TryOffDiff"``.
        **kwargs: Forwarded unchanged to the constructor of the selected class.

    Returns:
        A new instance of the selected model class.

    Raises:
        KeyError: If ``model_name`` is not the name of a ``ModelName`` member.
            The message lists the registered names.
    """
    try:
        model_class = ModelName[model_name].value
    except KeyError:
        # Same exception type as the plain enum lookup, but with a message that
        # tells the caller which names are actually available.
        valid_names = ", ".join(member.name for member in ModelName)
        raise KeyError(
            f"Unknown model name {model_name!r}; expected one of: {valid_names}"
        ) from None
    return model_class(**kwargs)