# ------------------------------------------------------------------------------------
# Karlo-v1.0.alpha
# Copyright (c) 2022 KakaoBrain. All Rights Reserved.
# ------------------------------------------------------------------------------------

import copy
import torch

from ldm.modules.karlo.kakao.modules import create_gaussian_diffusion
from ldm.modules.karlo.kakao.modules.unet import PLMImUNet


class Text2ImProgressiveModel(torch.nn.Module):
    """
    A decoder that generates 64x64px images based on the text prompt.

    :param config: yaml config to define the decoder.
    :param tokenizer: tokenizer used in clip.
    """

    def __init__(
        self,
        config,
        tokenizer,
    ):
        super().__init__()

        self._conf = config
        self._model_conf = config.model.hparams
        self._diffusion_kwargs = dict(
            steps=config.diffusion.steps,
            learn_sigma=config.diffusion.learn_sigma,
            sigma_small=config.diffusion.sigma_small,
            noise_schedule=config.diffusion.noise_schedule,
            use_kl=config.diffusion.use_kl,
            predict_xstart=config.diffusion.predict_xstart,
            rescale_learned_sigmas=config.diffusion.rescale_learned_sigmas,
            timestep_respacing=config.diffusion.timestep_respacing,
        )
        self._tokenizer = tokenizer

        self.model = self.create_plm_dec_model()

        cf_token, cf_mask = self.set_cf_text_tensor()
        self.register_buffer("cf_token", cf_token, persistent=False)
        self.register_buffer("cf_mask", cf_mask, persistent=False)

    @classmethod
    def load_from_checkpoint(cls, config, tokenizer, ckpt_path, strict: bool = True):
        ckpt = torch.load(ckpt_path, map_location="cpu")["state_dict"]

        model = cls(config, tokenizer)
        model.load_state_dict(ckpt, strict=strict)
        return model

    def create_plm_dec_model(self):
        image_size = self._model_conf.image_size
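        # With no explicit override, fall back to the standard UNet
        # channel-multiplier presets for each supported base resolution.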
        if self._model_conf.channel_mult == "":
            if image_size == 256:
                channel_mult = (1, 1, 2, 2, 4, 4)
            elif image_size == 128:
                channel_mult = (1, 1, 2, 3, 4)
            elif image_size == 64:
                channel_mult = (1, 2, 3, 4)
            else:
                raise ValueError(f"unsupported image size: {image_size}")
        else:
            channel_mult = tuple(
                int(ch_mult) for ch_mult in self._model_conf.channel_mult.split(",")
            )
            assert 2 ** (len(channel_mult) + 2) == image_size, (
                f"channel_mult {channel_mult} is incompatible with image size {image_size}"
            )

        # Convert attention resolutions (in pixels) to downsampling factors.
        attention_ds = [
            image_size // int(res)
            for res in self._model_conf.attention_resolutions.split(",")
        ]

        return PLMImUNet(
            text_ctx=self._model_conf.text_ctx,
            xf_width=self._model_conf.xf_width,
            in_channels=3,
            model_channels=self._model_conf.num_channels,
            out_channels=6 if self._model_conf.learn_sigma else 3,
            num_res_blocks=self._model_conf.num_res_blocks,
            attention_resolutions=tuple(attention_ds),
            dropout=self._model_conf.dropout,
            channel_mult=channel_mult,
            num_heads=self._model_conf.num_heads,
            num_head_channels=self._model_conf.num_head_channels,
            num_heads_upsample=self._model_conf.num_heads_upsample,
            use_scale_shift_norm=self._model_conf.use_scale_shift_norm,
            resblock_updown=self._model_conf.resblock_updown,
            clip_dim=self._model_conf.clip_dim,
            clip_emb_mult=self._model_conf.clip_emb_mult,
            clip_emb_type=self._model_conf.clip_emb_type,
            clip_emb_drop=self._model_conf.clip_emb_drop,
        )

    def set_cf_text_tensor(self):
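        # The unconditional ("null") text condition is the empty string, padded
        # to the model's text context length.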
        return self._tokenizer.padded_tokens_and_mask([""], self.model.text_ctx)

    def get_sample_fn(self, timestep_respacing):
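        # Respacing strings starting with "ddim" or "fast" select the
        # deterministic DDIM sampler; any other value uses ancestral (DDPM)
        # sampling.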
        use_ddim = timestep_respacing.startswith(("ddim", "fast"))

        diffusion_kwargs = copy.deepcopy(self._diffusion_kwargs)
        diffusion_kwargs.update(timestep_respacing=timestep_respacing)
        diffusion = create_gaussian_diffusion(**diffusion_kwargs)
        sample_fn = (
            diffusion.ddim_sample_loop_progressive
            if use_ddim
            else diffusion.p_sample_loop_progressive
        )

        return sample_fn

    def forward(
        self,
        txt_feat,
        txt_feat_seq,
        tok,
        mask,
        img_feat=None,
        cf_guidance_scales=None,
        timestep_respacing=None,
    ):
        # Classifier-free guidance must be enabled at inference time, and the
        # batch must carry a CLIP image embedding produced by the prior.
        assert cf_guidance_scales is not None and all(cf_guidance_scales > 0.0)
        assert img_feat is not None

        bsz = txt_feat.shape[0]
        img_sz = self._model_conf.image_size

        def guided_model_fn(x_t, ts, **kwargs):
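            # Classifier-free guidance: the batch is packed as [cond | uncond]
            # along dim 0 with identical x_t halves, so a single UNet call
            # evaluates both branches; the guided noise is
            #   eps = uncond_eps + scale * (cond_eps - uncond_eps).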
            half = x_t[: len(x_t) // 2]
            combined = torch.cat([half, half], dim=0)
            model_out = self.model(combined, ts, **kwargs)
            eps, rest = model_out[:, :3], model_out[:, 3:]
            cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
            half_eps = uncond_eps + cf_guidance_scales.view(-1, 1, 1, 1) * (
                cond_eps - uncond_eps
            )
            eps = torch.cat([half_eps, half_eps], dim=0)
            return torch.cat([eps, rest], dim=1)

        # The second half of the batch is conditioned on the learned null image
        # embedding (cf_param) instead of the real CLIP image embedding.
        cf_feat = self.model.cf_param.unsqueeze(0)
        cf_feat = cf_feat.expand(bsz // 2, -1)
        feat = torch.cat([img_feat, cf_feat.to(txt_feat.device)], dim=0)

        cond = {
            "y": feat,
            "txt_feat": txt_feat,
            "txt_feat_seq": txt_feat_seq,
            "mask": mask,
        }
        sample_fn = self.get_sample_fn(timestep_respacing)
        sample_outputs = sample_fn(
            guided_model_fn,
            (bsz, 3, img_sz, img_sz),
            noise=None,
            device=txt_feat.device,
            clip_denoised=True,
            model_kwargs=cond,
        )

        for out in sample_outputs:
            sample = out["sample"]
            # Guidance is always enabled (asserted above), so the batch holds
            # [cond | uncond] halves; yield only the conditional half.
            yield sample[: sample.shape[0] // 2]


class Text2ImModel(Text2ImProgressiveModel):
    """Non-progressive variant: runs the full sampling loop and returns only
    the final denoised sample."""

    def forward(
        self,
        txt_feat,
        txt_feat_seq,
        tok,
        mask,
        img_feat=None,
        cf_guidance_scales=None,
        timestep_respacing=None,
    ):
        last_out = None
        for out in super().forward(
            txt_feat,
            txt_feat_seq,
            tok,
            mask,
            img_feat,
            cf_guidance_scales,
            timestep_respacing,
        ):
            last_out = out
        return last_out
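

# ------------------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original file). The config,
# tokenizer, checkpoint name, and the upstream prior that produces `img_feat` are
# assumptions; only the class API above comes from this file.
# ------------------------------------------------------------------------------------
#
#   config = ...          # config object matching config.model.hparams / config.diffusion
#   tokenizer = ...       # the CLIP tokenizer exposing padded_tokens_and_mask()
#   decoder = Text2ImModel.load_from_checkpoint(config, tokenizer, "decoder.ckpt")
#   decoder.cuda().eval()
#
#   # txt_feat, txt_feat_seq, tok, mask come from CLIP; img_feat from the prior.
#   # All conditioning tensors are duplicated along dim 0 ([cond | uncond]).
#   image = decoder(
#       txt_feat, txt_feat_seq, tok, mask,
#       img_feat=img_feat,
#       cf_guidance_scales=torch.full((n_images,), 8.0, device="cuda"),
#       timestep_respacing="ddim25",
#   )


if __name__ == "__main__":
    # Self-contained sanity check of the guidance arithmetic used in
    # `guided_model_fn` above: with a guidance scale of 1.0 the guided eps must
    # reduce to the conditional eps.
    cond_eps = torch.randn(2, 3, 64, 64)
    uncond_eps = torch.randn(2, 3, 64, 64)
    scales = torch.ones(2)
    half_eps = uncond_eps + scales.view(-1, 1, 1, 1) * (cond_eps - uncond_eps)
    assert torch.allclose(half_eps, cond_eps, atol=1e-5)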