# ------------------------------------------------------------------------------------
# Karlo-v1.0.alpha
# Copyright (c) 2022 KakaoBrain. All Rights Reserved.
# ------------------------------------------------------------------------------------

import copy

import torch

from ldm.modules.karlo.kakao.modules import create_gaussian_diffusion
from ldm.modules.karlo.kakao.modules.xf import PriorTransformer

class PriorDiffusionModel(torch.nn.Module):
    """
    A prior that generates a CLIP image feature conditioned on the text prompt.

    :param config: yaml config to define the prior.
    :param tokenizer: tokenizer used in CLIP.
    :param clip_mean: mean to normalize the clip image feature (zero-mean, unit variance).
    :param clip_std: std to normalize the clip image feature (zero-mean, unit variance).
    """

    def __init__(self, config, tokenizer, clip_mean, clip_std):
        super().__init__()
        self._conf = config
        self._model_conf = config.model.hparams
        self._diffusion_kwargs = dict(
            steps=config.diffusion.steps,
            learn_sigma=config.diffusion.learn_sigma,
            sigma_small=config.diffusion.sigma_small,
            noise_schedule=config.diffusion.noise_schedule,
            use_kl=config.diffusion.use_kl,
            predict_xstart=config.diffusion.predict_xstart,
            rescale_learned_sigmas=config.diffusion.rescale_learned_sigmas,
            timestep_respacing=config.diffusion.timestep_respacing,
        )
        self._tokenizer = tokenizer

        # Statistics for (de)normalizing CLIP image features; excluded from the
        # checkpoint state dict via persistent=False.
        self.register_buffer("clip_mean", clip_mean[None, :], persistent=False)
        self.register_buffer("clip_std", clip_std[None, :], persistent=False)

        causal_mask = self.get_causal_mask()
        self.register_buffer("causal_mask", causal_mask, persistent=False)

        self.model = PriorTransformer(
            text_ctx=self._model_conf.text_ctx,
            xf_width=self._model_conf.xf_width,
            xf_layers=self._model_conf.xf_layers,
            xf_heads=self._model_conf.xf_heads,
            xf_final_ln=self._model_conf.xf_final_ln,
            clip_dim=self._model_conf.clip_dim,
        )

        # Token/mask for the empty ("") prompt, used as the unconditional
        # branch of classifier-free guidance.
        cf_token, cf_mask = self.set_cf_text_tensor()
        self.register_buffer("cf_token", cf_token, persistent=False)
        self.register_buffer("cf_mask", cf_mask, persistent=False)

    @classmethod
    def load_from_checkpoint(
        cls, config, tokenizer, clip_mean, clip_std, ckpt_path, strict: bool = True
    ):
        ckpt = torch.load(ckpt_path, map_location="cpu")["state_dict"]

        model = cls(config, tokenizer, clip_mean, clip_std)
        model.load_state_dict(ckpt, strict=strict)
        return model

    def set_cf_text_tensor(self):
        return self._tokenizer.padded_tokens_and_mask([""], self.model.text_ctx)

    def get_sample_fn(self, timestep_respacing):
        use_ddim = timestep_respacing.startswith(("ddim", "fast"))

        diffusion_kwargs = copy.deepcopy(self._diffusion_kwargs)
        diffusion_kwargs.update(timestep_respacing=timestep_respacing)
        diffusion = create_gaussian_diffusion(**diffusion_kwargs)
        sample_fn = diffusion.ddim_sample_loop if use_ddim else diffusion.p_sample_loop
        return sample_fn
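
    # Note: both "ddim..." and "fast..." respacing strings route to the DDIM
    # sampling loop. Assuming guided_diffusion-style respacing, "ddim25" would
    # run 25 DDIM steps, while a plain "25" would run 25 ancestral p_sample
    # steps over a respaced schedule.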

    def get_causal_mask(self):
        seq_len = self._model_conf.text_ctx + 4
        mask = torch.empty(seq_len, seq_len)
        mask.fill_(float("-inf"))
        mask.triu_(1)
        mask = mask[None, ...]
        return mask
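
    # Illustration: with text_ctx = 77, seq_len = 81 and the returned mask has
    # shape (1, 81, 81), with -inf strictly above the diagonal and 0 elsewhere,
    # so position i can attend only to positions j <= i. The 4 extra slots are
    # presumably the text-, timestep-, and image-embedding tokens plus a final
    # readout token, as in the unCLIP prior.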

    def forward(
        self,
        txt_feat,
        txt_feat_seq,
        mask,
        cf_guidance_scales=None,
        timestep_respacing=None,
        denoised_fn=True,
    ):
        # Classifier-free guidance should be enabled at inference time.
        assert cf_guidance_scales is not None and all(cf_guidance_scales > 0.0)

        bsz_ = txt_feat.shape[0]
        bsz = bsz_ // 2
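        # The incoming batch stacks the conditional text features in the first
        # half and the unconditional ("") features in the second half, so the
        # effective sample batch size is bsz = bsz_ // 2.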

        def guided_model_fn(x_t, ts, **kwargs):
            # Evaluate the model on the same latent twice (once per guidance
            # branch, via the duplicated kwargs) and mix the eps predictions.
            half = x_t[: len(x_t) // 2]
            combined = torch.cat([half, half], dim=0)
            model_out = self.model(combined, ts, **kwargs)
            eps, rest = (
                model_out[:, : int(x_t.shape[1])],
                model_out[:, int(x_t.shape[1]) :],
            )
            cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
            half_eps = uncond_eps + cf_guidance_scales.view(-1, 1) * (
                cond_eps - uncond_eps
            )
            eps = torch.cat([half_eps, half_eps], dim=0)
            return torch.cat([eps, rest], dim=1)
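
        # The guidance update above is eps = eps_uncond + s * (eps_cond - eps_uncond):
        # s = 1 recovers the conditional prediction, and s > 1 extrapolates
        # away from the unconditional one.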

        cond = {
            "text_emb": txt_feat,
            "text_enc": txt_feat_seq,
            "mask": mask,
            "causal_mask": self.causal_mask,
        }
        sample_fn = self.get_sample_fn(timestep_respacing)
        sample = sample_fn(
            guided_model_fn,
            (bsz_, self.model.clip_dim),
            noise=None,
            device=txt_feat.device,
            clip_denoised=False,
            denoised_fn=lambda x: torch.clamp(x, -10, 10),  # fixed clamp on x_0 predictions
            model_kwargs=cond,
        )

        # De-normalize back to the raw CLIP feature space and keep only the
        # conditional half of the batch.
        sample = (sample * self.clip_std) + self.clip_mean
        return sample[:bsz]
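

# ------------------------------------------------------------------------------------
# Minimal usage sketch (illustrative only; the config path, tokenizer, CLIP
# statistics, and checkpoint name below are hypothetical placeholders):
#
#   import torch
#   from omegaconf import OmegaConf                 # assuming an OmegaConf yaml config
#
#   config = OmegaConf.load("prior_config.yaml")    # hypothetical path
#   tokenizer = ...                                 # CLIP tokenizer used at training time
#   clip_mean, clip_std = ...                       # feature statistics saved with the checkpoint
#   prior = PriorDiffusionModel.load_from_checkpoint(
#       config, tokenizer, clip_mean, clip_std, "prior.ckpt"
#   )
#
#   # txt_feat / txt_feat_seq / mask stack the conditional rows on top of the
#   # classifier-free ("") rows, so the batch dimension is 2 * bsz.
#   img_feat = prior(
#       txt_feat, txt_feat_seq, mask,
#       cf_guidance_scales=torch.tensor([4.0]),
#       timestep_respacing="ddim25",
#   )
# ------------------------------------------------------------------------------------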