# ------------------------------------------------------------------------------------
# Karlo-v1.0.alpha
# Copyright (c) 2022 KakaoBrain. All Rights Reserved.
# ------------------------------------------------------------------------------------

import copy
import torch

from ldm.modules.karlo.kakao.modules import create_gaussian_diffusion
from ldm.modules.karlo.kakao.modules.xf import PriorTransformer


class PriorDiffusionModel(torch.nn.Module):
    """
    A prior that generates clip image feature based on the text prompt.

    :param config: yaml config to define the decoder.
    :param tokenizer: tokenizer used in clip.
    :param clip_mean: mean to normalize the clip image feature (zero-mean, unit variance).
    :param clip_std: std to noramlize the clip image feature (zero-mean, unit variance).
    """

    def __init__(self, config, tokenizer, clip_mean, clip_std):
        super().__init__()

        self._conf = config
        self._model_conf = config.model.hparams
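        # Stash the diffusion hyperparameters so samplers with a different
        # timestep respacing can be re-created on demand (see get_sample_fn).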
        self._diffusion_kwargs = dict(
            steps=config.diffusion.steps,
            learn_sigma=config.diffusion.learn_sigma,
            sigma_small=config.diffusion.sigma_small,
            noise_schedule=config.diffusion.noise_schedule,
            use_kl=config.diffusion.use_kl,
            predict_xstart=config.diffusion.predict_xstart,
            rescale_learned_sigmas=config.diffusion.rescale_learned_sigmas,
            timestep_respacing=config.diffusion.timestep_respacing,
        )
        self._tokenizer = tokenizer

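        # Whitening statistics for the CLIP image embedding, registered as
        # non-persistent buffers so they follow .to(device) but stay out of
        # the state dict.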
        self.register_buffer("clip_mean", clip_mean[None, :], persistent=False)
        self.register_buffer("clip_std", clip_std[None, :], persistent=False)

        causal_mask = self.get_causal_mask()
        self.register_buffer("causal_mask", causal_mask, persistent=False)

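        # Transformer backbone of the diffusion prior.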
        self.model = PriorTransformer(
            text_ctx=self._model_conf.text_ctx,
            xf_width=self._model_conf.xf_width,
            xf_layers=self._model_conf.xf_layers,
            xf_heads=self._model_conf.xf_heads,
            xf_final_ln=self._model_conf.xf_final_ln,
            clip_dim=self._model_conf.clip_dim,
        )

        cf_token, cf_mask = self.set_cf_text_tensor()
        self.register_buffer("cf_token", cf_token, persistent=False)
        self.register_buffer("cf_mask", cf_mask, persistent=False)

    @classmethod
    def load_from_checkpoint(
        cls, config, tokenizer, clip_mean, clip_std, ckpt_path, strict: bool = True
    ):
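        # Checkpoints are expected to store the weights under the "state_dict"
        # key; load on CPU and let the caller move the model afterwards.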
        ckpt = torch.load(ckpt_path, map_location="cpu")["state_dict"]

        model = cls(config, tokenizer, clip_mean, clip_std)
        model.load_state_dict(ckpt, strict=strict)
        return model

    def set_cf_text_tensor(self):
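        # Tokenize the empty prompt; the padded tokens and mask serve as the
        # unconditional (null-text) input for classifier-free guidance.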
        return self._tokenizer.padded_tokens_and_mask([""], self.model.text_ctx)

    def get_sample_fn(self, timestep_respacing):
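        # Respacing strings that start with "ddim" or "fast" select DDIM
        # sampling; anything else uses ancestral (p_sample_loop) sampling.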
        use_ddim = timestep_respacing.startswith(("ddim", "fast"))

        diffusion_kwargs = copy.deepcopy(self._diffusion_kwargs)
        diffusion_kwargs.update(timestep_respacing=timestep_respacing)
        diffusion = create_gaussian_diffusion(**diffusion_kwargs)
        sample_fn = diffusion.ddim_sample_loop if use_ddim else diffusion.p_sample_loop

        return sample_fn

    def get_causal_mask(self):
        # Sequence length of the prior transformer: text_ctx text tokens plus
        # four extra tokens (text embedding, timestep embedding, image
        # embedding, and a final query token).
        seq_len = self._model_conf.text_ctx + 4
        # Additive attention mask: -inf strictly above the diagonal forbids
        # attending to future positions; zeros elsewhere allow it.
        mask = torch.empty(seq_len, seq_len)
        mask.fill_(float("-inf"))
        mask.triu_(1)
        mask = mask[None, ...]  # broadcastable over batch/heads
        return mask

    def forward(
        self,
        txt_feat,
        txt_feat_seq,
        mask,
        cf_guidance_scales=None,
        timestep_respacing=None,
        denoised_fn=True,
    ):
        # Classifier-free guidance must be enabled at inference time, with a
        # positive scale for every sample in the batch.
        assert cf_guidance_scales is not None and (cf_guidance_scales > 0.0).all()

        # The batch stacks conditional and unconditional inputs along dim 0,
        # so the number of actual samples is half the tensor batch size.
        bsz_ = txt_feat.shape[0]
        bsz = bsz_ // 2

        def guided_model_fn(x_t, ts, **kwargs):
            # Classifier-free guidance: both halves of the batch share the same
            # x_t; the model kwargs carry conditional text features for the
            # first half and null-text features for the second.
            half = x_t[: len(x_t) // 2]
            combined = torch.cat([half, half], dim=0)
            model_out = self.model(combined, ts, **kwargs)
            # Split the output into the predicted eps and any extra channels
            # (e.g. a learned variance), which pass through unguided.
            eps, rest = (
                model_out[:, : int(x_t.shape[1])],
                model_out[:, int(x_t.shape[1]) :],
            )
            cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
            # Move from the unconditional prediction toward the conditional
            # one, scaled per-sample by the guidance weight.
            half_eps = uncond_eps + cf_guidance_scales.view(-1, 1) * (
                cond_eps - uncond_eps
            )
            eps = torch.cat([half_eps, half_eps], dim=0)
            return torch.cat([eps, rest], dim=1)

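        # Per-step conditioning forwarded to the prior transformer.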
        cond = {
            "text_emb": txt_feat,
            "text_enc": txt_feat_seq,
            "mask": mask,
            "causal_mask": self.causal_mask,
        }
        sample_fn = self.get_sample_fn(timestep_respacing)
        sample = sample_fn(
            guided_model_fn,
            (bsz_, self.model.clip_dim),
            noise=None,
            device=txt_feat.device,
            clip_denoised=False,
            # Clamp denoised predictions to a plausible range for the whitened
            # CLIP embedding space rather than the image-space default [-1, 1].
            denoised_fn=(lambda x: torch.clamp(x, -10, 10)) if denoised_fn else None,
            model_kwargs=cond,
        )
        # Undo the whitening, then keep only the conditional half of the batch.
        sample = (sample * self.clip_std) + self.clip_mean

        return sample[:bsz]
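

# ------------------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the module). The config,
# checkpoint path, tokenizer, and CLIP statistics below are hypothetical
# placeholders; only the duplicated [conditional; unconditional] batch layout
# and the argument shapes mirror what `forward` actually expects.
# ------------------------------------------------------------------------------------
#
#   config = ...               # config exposing model.hparams and diffusion sections
#   tokenizer = ...            # CLIP tokenizer providing padded_tokens_and_mask()
#   clip_mean, clip_std = ...  # (clip_dim,) whitening statistics
#
#   prior = PriorDiffusionModel.load_from_checkpoint(
#       config, tokenizer, clip_mean, clip_std, "prior.ckpt"  # hypothetical path
#   ).cuda().eval()
#
#   bsz = 2  # two prompts; inputs below stack [cond; uncond] along dim 0
#   with torch.no_grad():
#       img_feat = prior(
#           txt_feat,      # (2 * bsz, clip_dim) CLIP text embeddings
#           txt_feat_seq,  # (2 * bsz, text_ctx, d_text) token-level encodings
#           mask,          # (2 * bsz, text_ctx) token attention mask
#           cf_guidance_scales=torch.full((bsz,), 4.0, device="cuda"),
#           timestep_respacing="25",
#       )
#   # img_feat: (bsz, clip_dim) CLIP image embeddings, ready for the decoder.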