# ------------------------------------------------------------------------------------
# Karlo-v1.0.alpha
# Copyright (c) 2022 KakaoBrain. All Rights Reserved.
# ------------------------------------------------------------------------------------
import copy
import torch
from ldm.modules.karlo.kakao.modules import create_gaussian_diffusion
from ldm.modules.karlo.kakao.modules.unet import PLMImUNet


class Text2ImProgressiveModel(torch.nn.Module):
    """
    A decoder that generates 64x64px images based on the text prompt.

    :param config: yaml config to define the decoder.
    :param tokenizer: tokenizer used in clip.
    """

    def __init__(
        self,
        config,
        tokenizer,
    ):
        super().__init__()

        self._conf = config
        self._model_conf = config.model.hparams
        self._diffusion_kwargs = dict(
            steps=config.diffusion.steps,
            learn_sigma=config.diffusion.learn_sigma,
            sigma_small=config.diffusion.sigma_small,
            noise_schedule=config.diffusion.noise_schedule,
            use_kl=config.diffusion.use_kl,
            predict_xstart=config.diffusion.predict_xstart,
            rescale_learned_sigmas=config.diffusion.rescale_learned_sigmas,
            timestep_respacing=config.diffusion.timestep_respacing,
        )
        self._tokenizer = tokenizer

        self.model = self.create_plm_dec_model()

        cf_token, cf_mask = self.set_cf_text_tensor()
        self.register_buffer("cf_token", cf_token, persistent=False)
        self.register_buffer("cf_mask", cf_mask, persistent=False)

    @classmethod
    def load_from_checkpoint(cls, config, tokenizer, ckpt_path, strict: bool = True):
        ckpt = torch.load(ckpt_path, map_location="cpu")["state_dict"]

        model = cls(config, tokenizer)
        model.load_state_dict(ckpt, strict=strict)
        return model

    def create_plm_dec_model(self):
        image_size = self._model_conf.image_size
        if self._model_conf.channel_mult == "":
            if image_size == 256:
                channel_mult = (1, 1, 2, 2, 4, 4)
            elif image_size == 128:
                channel_mult = (1, 1, 2, 3, 4)
            elif image_size == 64:
                channel_mult = (1, 2, 3, 4)
            else:
                raise ValueError(f"unsupported image size: {image_size}")
        else:
            channel_mult = tuple(
                int(ch_mult) for ch_mult in self._model_conf.channel_mult.split(",")
            )
            assert 2 ** (len(channel_mult) + 2) == image_size

        attention_ds = []
        for res in self._model_conf.attention_resolutions.split(","):
            attention_ds.append(image_size // int(res))

        return PLMImUNet(
            text_ctx=self._model_conf.text_ctx,
            xf_width=self._model_conf.xf_width,
            in_channels=3,
            model_channels=self._model_conf.num_channels,
            out_channels=6 if self._model_conf.learn_sigma else 3,
            num_res_blocks=self._model_conf.num_res_blocks,
            attention_resolutions=tuple(attention_ds),
            dropout=self._model_conf.dropout,
            channel_mult=channel_mult,
            num_heads=self._model_conf.num_heads,
            num_head_channels=self._model_conf.num_head_channels,
            num_heads_upsample=self._model_conf.num_heads_upsample,
            use_scale_shift_norm=self._model_conf.use_scale_shift_norm,
            resblock_updown=self._model_conf.resblock_updown,
            clip_dim=self._model_conf.clip_dim,
            clip_emb_mult=self._model_conf.clip_emb_mult,
            clip_emb_type=self._model_conf.clip_emb_type,
            clip_emb_drop=self._model_conf.clip_emb_drop,
        )

    def set_cf_text_tensor(self):
        return self._tokenizer.padded_tokens_and_mask([""], self.model.text_ctx)

    def get_sample_fn(self, timestep_respacing):
        use_ddim = timestep_respacing.startswith(("ddim", "fast"))

        diffusion_kwargs = copy.deepcopy(self._diffusion_kwargs)
        diffusion_kwargs.update(timestep_respacing=timestep_respacing)
        diffusion = create_gaussian_diffusion(**diffusion_kwargs)
        sample_fn = (
            diffusion.ddim_sample_loop_progressive
            if use_ddim
            else diffusion.p_sample_loop_progressive
        )

        return sample_fn

    def forward(
        self,
        txt_feat,
        txt_feat_seq,
        tok,
        mask,
        img_feat=None,
        cf_guidance_scales=None,
        timestep_respacing=None,
    ):
        # cfg should be enabled in inference
        assert cf_guidance_scales is not None and all(cf_guidance_scales > 0.0)
        assert img_feat is not None

        bsz = txt_feat.shape[0]
        img_sz = self._model_conf.image_size

        def guided_model_fn(x_t, ts, **kwargs):
            # Classifier-free guidance: the batch is laid out as
            # [conditional | unconditional]; both halves share the same x_t and the
            # predicted noise is mixed as eps_uncond + scale * (eps_cond - eps_uncond).
            half = x_t[: len(x_t) // 2]
            combined = torch.cat([half, half], dim=0)
            model_out = self.model(combined, ts, **kwargs)
            eps, rest = model_out[:, :3], model_out[:, 3:]
            cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
            half_eps = uncond_eps + cf_guidance_scales.view(-1, 1, 1, 1) * (
                cond_eps - uncond_eps
            )
            eps = torch.cat([half_eps, half_eps], dim=0)
            return torch.cat([eps, rest], dim=1)

        # The learned null image embedding fills the unconditional half of the batch.
        cf_feat = self.model.cf_param.unsqueeze(0)
        cf_feat = cf_feat.expand(bsz // 2, -1)
        feat = torch.cat([img_feat, cf_feat.to(txt_feat.device)], dim=0)

        cond = {
            "y": feat,
            "txt_feat": txt_feat,
            "txt_feat_seq": txt_feat_seq,
            "mask": mask,
        }
        sample_fn = self.get_sample_fn(timestep_respacing)
        sample_outputs = sample_fn(
            guided_model_fn,
            (bsz, 3, img_sz, img_sz),
            noise=None,
            device=txt_feat.device,
            clip_denoised=True,
            model_kwargs=cond,
        )

        for out in sample_outputs:
            sample = out["sample"]
            # With guidance enabled, only the conditional half of the batch is yielded.
            yield sample if cf_guidance_scales is None else sample[
                : sample.shape[0] // 2
            ]
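
# Illustrative sketch: because Text2ImProgressiveModel.forward is a generator, iterating
# it yields the running 64x64 sample after every respaced denoising step, which can be
# used to visualize the diffusion trajectory. `progressive_model`, the feature tensors,
# and `scales` are assumed to be prepared as in the usage sketch at the bottom of this
# file; they are not part of this module.
#
#   frames = [
#       s.cpu()
#       for s in progressive_model(txt_feat, txt_feat_seq, tok, mask,
#                                  img_feat=img_feat,
#                                  cf_guidance_scales=scales,
#                                  timestep_respacing="25")
#   ]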


class Text2ImModel(Text2ImProgressiveModel):
    """
    Same decoder, but runs the progressive sampler to completion and returns only the
    final sample instead of yielding intermediate steps.
    """

    def forward(
        self,
        txt_feat,
        txt_feat_seq,
        tok,
        mask,
        img_feat=None,
        cf_guidance_scales=None,
        timestep_respacing=None,
    ):
        last_out = None
        for out in super().forward(
            txt_feat,
            txt_feat_seq,
            tok,
            mask,
            img_feat,
            cf_guidance_scales,
            timestep_respacing,
        ):
            last_out = out
        return last_out
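
# Minimal usage sketch: `config`, `tokenizer`, `ckpt_path`, and the feature shapes below
# (768-dim CLIP embeddings, a 77-token context) are assumptions about how the surrounding
# pipeline prepares its inputs, not guarantees made by this module.
#
#   model = Text2ImModel.load_from_checkpoint(config, tokenizer, ckpt_path).cuda().eval()
#   # Conditional and unconditional rows are concatenated along the batch dimension,
#   # so one prompt plus its empty-string counterpart gives a batch of two.
#   txt_feat = torch.randn(2, 768).cuda()          # pooled CLIP text embeddings
#   txt_feat_seq = torch.randn(2, 77, 768).cuda()  # per-token CLIP text features
#   tok, mask = tokenizer.padded_tokens_and_mask(
#       ["a photo of a cat", ""], model.model.text_ctx
#   )
#   img_feat = torch.randn(1, 768).cuda()          # CLIP image embedding from the prior
#   with torch.no_grad():
#       sample = model(txt_feat, txt_feat_seq, tok, mask,
#                      img_feat=img_feat,
#                      cf_guidance_scales=torch.tensor([8.0]).cuda(),
#                      timestep_respacing="ddim25")  # -> (1, 3, 64, 64) tensor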