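"""Loading scaffolding for the Karlo (unCLIP) pipeline: a CLIP text/image encoder,
a diffusion prior, a text-to-image decoder, and a 64->256 super-resolution model,
each restored from a checkpoint and kept on the GPU in eval mode."""
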
import os
import logging
import torch

from omegaconf import OmegaConf

from ldm.modules.karlo.kakao.models.clip import CustomizedCLIP, CustomizedTokenizer
from ldm.modules.karlo.kakao.models.prior_model import PriorDiffusionModel
from ldm.modules.karlo.kakao.models.decoder_model import Text2ImProgressiveModel
from ldm.modules.karlo.kakao.models.sr_64_256 import ImprovedSupRes64to256ProgressiveModel

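# Per-stage sampling presets ("default" vs. "fast"). The "*_sm" strings select each
# stage's sampling method (here they read as step counts), and "*_cf_scale" sets the
# classifier-free guidance scale for that stage.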
SAMPLING_CONF = {
    "default": {
        "prior_sm": "25",
        "prior_n_samples": 1,
        "prior_cf_scale": 4.0,
        "decoder_sm": "50",
        "decoder_cf_scale": 8.0,
        "sr_sm": "7",
    },
    "fast": {
        "prior_sm": "25",
        "prior_n_samples": 1,
        "prior_cf_scale": 4.0,
        "decoder_sm": "25",
        "decoder_cf_scale": 8.0,
        "sr_sm": "7",
    },
}

CKPT_PATH = {
    "prior": "prior-ckpt-step=01000000-of-01000000.ckpt",
    "decoder": "decoder-ckpt-step=01000000-of-01000000.ckpt",
    "sr_256": "improved-sr-ckpt-step=1.2M.ckpt",
}


class BaseSampler:
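    """Scaffolding for the Karlo pipeline: holds the CLIP encoder, the diffusion
    prior, the text-to-image decoder, and the 64->256 super-resolution model,
    each loaded from a checkpoint resolved relative to ``root_dir``.

    ``sampling_type`` picks one of the presets defined in ``SAMPLING_CONF``.
    """
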
    _PRIOR_CLASS = PriorDiffusionModel
    _DECODER_CLASS = Text2ImProgressiveModel
    _SR256_CLASS = ImprovedSupRes64to256ProgressiveModel

    def __init__(
        self,
        root_dir: str,
        sampling_type: str = "fast",
    ):
        self._root_dir = root_dir

        sampling_type = SAMPLING_CONF[sampling_type]
        self._prior_sm = sampling_type["prior_sm"]
        self._prior_n_samples = sampling_type["prior_n_samples"]
        self._prior_cf_scale = sampling_type["prior_cf_scale"]

        assert self._prior_n_samples == 1

        self._decoder_sm = sampling_type["decoder_sm"]
        self._decoder_cf_scale = sampling_type["decoder_cf_scale"]

        self._sr_sm = sampling_type["sr_sm"]

    def __repr__(self):
        line = ""
        line += f"Prior, sampling method: {self._prior_sm}, cf_scale: {self._prior_cf_scale}\n"
        line += f"Decoder, sampling method: {self._decoder_sm}, cf_scale: {self._decoder_cf_scale}\n"
        line += f"SR(64->256), sampling method: {self._sr_sm}"

        return line

    def load_clip(self, clip_path: str):
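        """Load the CLIP encoder from ``root_dir/clip_path``, script it with
        TorchScript, and keep it (plus a tokenizer) on the GPU in eval mode."""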
        clip = CustomizedCLIP.load_from_checkpoint(
            os.path.join(self._root_dir, clip_path)
        )
        clip = torch.jit.script(clip)
        clip.cuda()
        clip.eval()

        self._clip = clip
        self._tokenizer = CustomizedTokenizer()

    def load_prior(
        self,
        ckpt_path: str,
        clip_stat_path: str,
        prior_config: str = "configs/prior_1B_vit_l.yaml",
    ):
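        """Load the diffusion prior from ``ckpt_path`` together with the CLIP
        feature mean/std statistics stored at ``clip_stat_path``, both resolved
        relative to ``root_dir``, and put the model on the GPU in eval mode."""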
        logging.info(f"Loading prior: {ckpt_path}")

        config = OmegaConf.load(prior_config)
        clip_mean, clip_std = torch.load(
            os.path.join(self._root_dir, clip_stat_path), map_location="cpu"
        )

        prior = self._PRIOR_CLASS.load_from_checkpoint(
            config,
            self._tokenizer,
            clip_mean,
            clip_std,
            os.path.join(self._root_dir, ckpt_path),
            strict=True,
        )
        prior.cuda()
        prior.eval()
        logging.info("done.")

        self._prior = prior

    def load_decoder(self, ckpt_path: str, decoder_config: str = "configs/decoder_900M_vit_l.yaml"):
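        """Load the text-to-image decoder from its OmegaConf config and checkpoint,
        then put it on the GPU in eval mode."""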
        logging.info(f"Loading decoder: {ckpt_path}")

        config = OmegaConf.load(decoder_config)
        decoder = self._DECODER_CLASS.load_from_checkpoint(
            config,
            self._tokenizer,
            os.path.join(self._root_dir, ckpt_path),
            strict=True,
        )
        decoder.cuda()
        decoder.eval()
        logging.info("done.")

        self._decoder = decoder

    def load_sr_64_256(self, ckpt_path: str, sr_config: str = "configs/improved_sr_64_256_1.4B.yaml"):
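        """Load the 64->256 super-resolution model from its config and checkpoint,
        then put it on the GPU in eval mode."""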
        logging.info(f"Loading SR(64->256): {ckpt_path}")

        config = OmegaConf.load(sr_config)
        sr = self._SR256_CLASS.load_from_checkpoint(
            config, os.path.join(self._root_dir, ckpt_path), strict=True
        )
        sr.cuda()
        sr.eval()
        logging.info("done.")

        self._sr_64_256 = sr
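

# Minimal usage sketch (illustrative only, not part of this module): a caller would
# typically instantiate a sampler, load every stage once, and then run its own
# sampling loop. The checkpoint/statistics filenames below are assumptions; only
# the CKPT_PATH entries come from this file.
#
#     sampler = BaseSampler(root_dir="/path/to/karlo", sampling_type="fast")
#     sampler.load_clip("ViT-L-14.pt")  # assumed CLIP checkpoint name
#     sampler.load_prior(CKPT_PATH["prior"], clip_stat_path="ViT-L-14_stats.th")  # assumed stats file
#     sampler.load_decoder(CKPT_PATH["decoder"])
#     sampler.load_sr_64_256(CKPT_PATH["sr_256"])
#     print(sampler)  # reports the selected sampling configuration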