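"""Wrapper around the thu-ml/unidiffuser-v1 checkpoint.

Builds on the sampling code pulled in through the `unidiffuser` submodule and
exposes a single `Model.run()` entry point covering text-to-image ('t2i'),
image-to-text ('i2t'), joint ('joint'), unconditional image ('i') and text
('t') generation, plus the round trips 't2i2t' and 'i2t2i'.
"""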

from __future__ import annotations

import pathlib
import random
import sys
from typing import Callable

import clip
import einops
import numpy as np
import PIL.Image
import torch
from huggingface_hub import snapshot_download

repo_dir = pathlib.Path(__file__).parent
submodule_dir = repo_dir / 'unidiffuser'
sys.path.append(submodule_dir.as_posix())

import utils
from configs.sample_unidiffuser_v1 import get_config
from dpm_solver_pp import DPM_Solver, NoiseScheduleVP
from libs.autoencoder import FrozenAutoencoderKL
from libs.autoencoder import get_model as get_autoencoder
from libs.caption_decoder import CaptionDecoder
from libs.clip import FrozenCLIPEmbedder

model_dir = repo_dir / 'models'
if not model_dir.exists():
    snapshot_download('thu-ml/unidiffuser-v1',
                      repo_type='model',
                      local_dir=model_dir)


def stable_diffusion_beta_schedule(linear_start=0.00085,
                                   linear_end=0.0120,
                                   n_timestep=1000):
    """Discrete beta schedule used by Stable Diffusion: the betas are the
    squares of a linear ramp from sqrt(linear_start) to sqrt(linear_end)
    over `n_timestep` steps."""
    _betas = (torch.linspace(linear_start**0.5,
                             linear_end**0.5,
                             n_timestep,
                             dtype=torch.float64)**2)
    return _betas.numpy()
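

# Wraps the UniDiffuser U-ViT ('nnet'), the frozen KL autoencoder, the CLIP
# text/image encoders and the caption decoder behind a single `run()` entry
# point.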
class Model:
    def __init__(self):
        self.device = torch.device(
            'cuda:0' if torch.cuda.is_available() else 'cpu')
        self.config = get_config()
        self.nnet = self.load_model()
        self.caption_decoder = CaptionDecoder(device=self.device,
                                              **self.config.caption_decoder)
        self.clip_text_model = self.load_clip_text_model()
        self.autoencoder = self.load_autoencoder()
        self.clip_img_model, self.clip_img_model_preprocess = clip.load(
            'ViT-B/32', device=self.device, jit=False)
        self.empty_context = self.clip_text_model.encode([''])[0]
        self.betas = stable_diffusion_beta_schedule()
        self.N = len(self.betas)

    @property
    def use_caption_decoder(self) -> bool:
        # The caption decoder bridges CLIP text embeddings (clip_text_dim) and
        # the lower-dimensional text space (text_dim) consumed by the U-ViT;
        # it can be skipped only for plain t2i when the dims already match.
        return (self.config.text_dim < self.config.clip_text_dim
                or self.config.mode != 't2i')

    def load_model(self,
                   model_path: str = 'models/uvit_v1.pth') -> torch.nn.Module:
        model = utils.get_nnet(**self.config.nnet)
        model.load_state_dict(torch.load(model_path, map_location='cpu'))
        model.to(self.device)
        model.eval()
        return model

    def load_clip_text_model(self) -> FrozenCLIPEmbedder:
        clip_text_model = FrozenCLIPEmbedder(device=self.device)
        clip_text_model.to(self.device)
        clip_text_model.eval()
        return clip_text_model

    def load_autoencoder(self) -> FrozenAutoencoderKL:
        autoencoder = get_autoencoder(**self.config.autoencoder)
        autoencoder.to(self.device)
        return autoencoder

    def split(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        C, H, W = self.config.z_shape
        z_dim = C * H * W
        z, clip_img = x.split([z_dim, self.config.clip_img_dim], dim=1)
        z = einops.rearrange(z, 'B (C H W) -> B C H W', C=C, H=H, W=W)
        clip_img = einops.rearrange(clip_img,
                                    'B (L D) -> B L D',
                                    L=1,
                                    D=self.config.clip_img_dim)
        return z, clip_img

    def combine(self, z: torch.Tensor,
                clip_img: torch.Tensor) -> torch.Tensor:
        z = einops.rearrange(z, 'B C H W -> B (C H W)')
        clip_img = einops.rearrange(clip_img, 'B L D -> B (L D)')
        return torch.concat([z, clip_img], dim=-1)

    def t2i_nnet(self, x, timesteps, text):
        """Text-conditioned model output with classifier-free guidance.

        `text` is the low-dimensional version of the CLIP text embedding.

        1. Compute the conditional model output.
        2. Compute the unconditional model output:
           - config.sample.t2i_cfg_mode == 'empty_token': classic CFG using
             the CLIP embedding of the empty string.
           - config.sample.t2i_cfg_mode == 'true_uncond': the unconditional
             model learned jointly by UniDiffuser (text replaced by noise,
             t_text set to N).
        3. Return a linear combination of the two outputs.
        """
        z, clip_img = self.split(x)
        t_text = torch.zeros(timesteps.size(0),
                             dtype=torch.int,
                             device=self.device)
        z_out, clip_img_out, text_out = self.nnet(
            z,
            clip_img,
            text=text,
            t_img=timesteps,
            t_text=t_text,
            data_type=torch.zeros_like(
                t_text, device=self.device, dtype=torch.int) +
            self.config.data_type)
        x_out = self.combine(z_out, clip_img_out)

        if self.config.sample.scale == 0.:
            return x_out

        if self.config.sample.t2i_cfg_mode == 'empty_token':
            _empty_context = einops.repeat(self.empty_context,
                                           'L D -> B L D',
                                           B=x.size(0))
            if self.use_caption_decoder:
                _empty_context = self.caption_decoder.encode_prefix(
                    _empty_context)
            z_out_uncond, clip_img_out_uncond, text_out_uncond = self.nnet(
                z,
                clip_img,
                text=_empty_context,
                t_img=timesteps,
                t_text=t_text,
                data_type=torch.zeros_like(
                    t_text, device=self.device, dtype=torch.int) +
                self.config.data_type)
            x_out_uncond = self.combine(z_out_uncond, clip_img_out_uncond)
        elif self.config.sample.t2i_cfg_mode == 'true_uncond':
            text_N = torch.randn_like(text)  # 3 other possible choices
            z_out_uncond, clip_img_out_uncond, text_out_uncond = self.nnet(
                z,
                clip_img,
                text=text_N,
                t_img=timesteps,
                t_text=torch.ones_like(timesteps) * self.N,
                data_type=torch.zeros_like(
                    t_text, device=self.device, dtype=torch.int) +
                self.config.data_type)
            x_out_uncond = self.combine(z_out_uncond, clip_img_out_uncond)
        else:
            raise NotImplementedError
        return x_out + self.config.sample.scale * (x_out - x_out_uncond)

    def i_nnet(self, x, timesteps):
        """Unconditional image model: the text input is pure noise at t_text=N."""
        z, clip_img = self.split(x)
        text = torch.randn(x.size(0),
                           77,
                           self.config.text_dim,
                           device=self.device)
        t_text = torch.ones_like(timesteps) * self.N
        z_out, clip_img_out, text_out = self.nnet(
            z,
            clip_img,
            text=text,
            t_img=timesteps,
            t_text=t_text,
            data_type=torch.zeros_like(
                t_text, device=self.device, dtype=torch.int) +
            self.config.data_type)
        x_out = self.combine(z_out, clip_img_out)
        return x_out

    def t_nnet(self, x, timesteps):
        """Unconditional text model: the image inputs are pure noise at t_img=N."""
        z = torch.randn(x.size(0), *self.config.z_shape, device=self.device)
        clip_img = torch.randn(x.size(0),
                               1,
                               self.config.clip_img_dim,
                               device=self.device)
        z_out, clip_img_out, text_out = self.nnet(
            z,
            clip_img,
            text=x,
            t_img=torch.ones_like(timesteps) * self.N,
            t_text=timesteps,
            data_type=torch.zeros_like(
                timesteps, device=self.device, dtype=torch.int) +
            self.config.data_type)
        return text_out

    def i2t_nnet(self, x, timesteps, z, clip_img):
        """Image-conditioned text model output with classifier-free guidance.

        1. Compute the conditional model output.
        2. Compute the unconditional model output (image inputs replaced by
           noise at t_img=N).
        3. Return a linear combination of the two outputs.
        """
        t_img = torch.zeros(timesteps.size(0),
                            dtype=torch.int,
                            device=self.device)
        z_out, clip_img_out, text_out = self.nnet(
            z,
            clip_img,
            text=x,
            t_img=t_img,
            t_text=timesteps,
            data_type=torch.zeros_like(
                t_img, device=self.device, dtype=torch.int) +
            self.config.data_type)
        if self.config.sample.scale == 0.:
            return text_out

        z_N = torch.randn_like(z)  # 3 other possible choices
        clip_img_N = torch.randn_like(clip_img)
        z_out_uncond, clip_img_out_uncond, text_out_uncond = self.nnet(
            z_N,
            clip_img_N,
            text=x,
            t_img=torch.ones_like(timesteps) * self.N,
            t_text=timesteps,
            data_type=torch.zeros_like(
                timesteps, device=self.device, dtype=torch.int) +
            self.config.data_type)
        return text_out + self.config.sample.scale * (text_out -
                                                      text_out_uncond)

    def split_joint(self, x):
        C, H, W = self.config.z_shape
        z_dim = C * H * W
        z, clip_img, text = x.split(
            [z_dim, self.config.clip_img_dim, 77 * self.config.text_dim],
            dim=1)
        z = einops.rearrange(z, 'B (C H W) -> B C H W', C=C, H=H, W=W)
        clip_img = einops.rearrange(clip_img,
                                    'B (L D) -> B L D',
                                    L=1,
                                    D=self.config.clip_img_dim)
        text = einops.rearrange(text,
                                'B (L D) -> B L D',
                                L=77,
                                D=self.config.text_dim)
        return z, clip_img, text

    def combine_joint(self, z: torch.Tensor, clip_img: torch.Tensor,
                      text: torch.Tensor) -> torch.Tensor:
        z = einops.rearrange(z, 'B C H W -> B (C H W)')
        clip_img = einops.rearrange(clip_img, 'B L D -> B (L D)')
        text = einops.rearrange(text, 'B L D -> B (L D)')
        return torch.concat([z, clip_img, text], dim=-1)

    def joint_nnet(self, x, timesteps):
        """Joint image-and-text model output with classifier-free guidance."""
        z, clip_img, text = self.split_joint(x)
        z_out, clip_img_out, text_out = self.nnet(
            z,
            clip_img,
            text=text,
            t_img=timesteps,
            t_text=timesteps,
            data_type=torch.zeros_like(
                timesteps, device=self.device, dtype=torch.int) +
            self.config.data_type)
        x_out = self.combine_joint(z_out, clip_img_out, text_out)

        if self.config.sample.scale == 0.:
            return x_out

        z_noise = torch.randn(x.size(0),
                              *self.config.z_shape,
                              device=self.device)
        clip_img_noise = torch.randn(x.size(0),
                                     1,
                                     self.config.clip_img_dim,
                                     device=self.device)
        text_noise = torch.randn(x.size(0),
                                 77,
                                 self.config.text_dim,
                                 device=self.device)
        _, _, text_out_uncond = self.nnet(
            z_noise,
            clip_img_noise,
            text=text,
            t_img=torch.ones_like(timesteps) * self.N,
            t_text=timesteps,
            data_type=torch.zeros_like(
                timesteps, device=self.device, dtype=torch.int) +
            self.config.data_type)
        z_out_uncond, clip_img_out_uncond, _ = self.nnet(
            z,
            clip_img,
            text=text_noise,
            t_img=timesteps,
            t_text=torch.ones_like(timesteps) * self.N,
            data_type=torch.zeros_like(
                timesteps, device=self.device, dtype=torch.int) +
            self.config.data_type)
        x_out_uncond = self.combine_joint(z_out_uncond, clip_img_out_uncond,
                                          text_out_uncond)
        return x_out + self.config.sample.scale * (x_out - x_out_uncond)

    def encode(self, _batch):
        return self.autoencoder.encode(_batch)

    def decode(self, _batch):
        return self.autoencoder.decode(_batch)
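
    # prepare_contexts() returns (text CLIP embeddings, image latent moments,
    # CLIP image features). Entries that the current mode does not condition
    # on are left as random placeholders.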
    def prepare_contexts(
            self) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        resolution = self.config.z_shape[-1] * 8

        contexts = torch.randn(self.config.n_samples, 77,
                               self.config.clip_text_dim).to(self.device)
        img_contexts = torch.randn(self.config.n_samples,
                                   2 * self.config.z_shape[0],
                                   self.config.z_shape[1],
                                   self.config.z_shape[2])
        clip_imgs = torch.randn(self.config.n_samples, 1,
                                self.config.clip_img_dim)

        if self.config.mode in ['t2i', 't2i2t']:
            prompts = [self.config.prompt] * self.config.n_samples
            contexts = self.clip_text_model.encode(prompts)
        elif self.config.mode in ['i2t', 'i2t2i']:
            img_contexts = []
            clip_imgs = []

            def get_img_feature(image):
                image = np.array(image).astype(np.uint8)
                image = utils.center_crop(resolution, resolution, image)
                clip_img_feature = self.clip_img_model.encode_image(
                    self.clip_img_model_preprocess(
                        PIL.Image.fromarray(image)).unsqueeze(0).to(
                            self.device))
                image = (image / 127.5 - 1.0).astype(np.float32)
                image = einops.rearrange(image, 'h w c -> 1 c h w')
                image = torch.tensor(image, device=self.device)
                moments = self.autoencoder.encode_moments(image)
                return clip_img_feature, moments

            image = PIL.Image.open(self.config.img).convert('RGB')
            clip_img, img_context = get_img_feature(image)

            img_contexts.append(img_context)
            clip_imgs.append(clip_img)
            img_contexts = img_contexts * self.config.n_samples
            clip_imgs = clip_imgs * self.config.n_samples

            img_contexts = torch.concat(img_contexts, dim=0)
            clip_imgs = torch.stack(clip_imgs, dim=0)
        return contexts, img_contexts, clip_imgs

    def unpreprocess(self, v: torch.Tensor) -> torch.Tensor:
        # map decoded images from [-1, 1] back to B C H W tensors in [0, 1]
        v = 0.5 * (v + 1.)
        v.clamp_(0., 1.)
        return v
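
    # get_sample_fn() builds a DPM-Solver sampler over the discrete noise
    # schedule; `mode` picks which wrapper above (joint / t2i / i2t / i / t)
    # serves as the model function.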
    def get_sample_fn(self, _n_samples: int) -> Callable:
        def sample_fn(mode: str, **kwargs):
            _z_init = torch.randn(_n_samples,
                                  *self.config.z_shape,
                                  device=self.device)
            _clip_img_init = torch.randn(_n_samples,
                                         1,
                                         self.config.clip_img_dim,
                                         device=self.device)
            _text_init = torch.randn(_n_samples,
                                     77,
                                     self.config.text_dim,
                                     device=self.device)
            if mode == 'joint':
                _x_init = self.combine_joint(_z_init, _clip_img_init,
                                             _text_init)
            elif mode in ['t2i', 'i']:
                _x_init = self.combine(_z_init, _clip_img_init)
            elif mode in ['i2t', 't']:
                _x_init = _text_init
            noise_schedule = NoiseScheduleVP(schedule='discrete',
                                             betas=torch.tensor(
                                                 self.betas,
                                                 device=self.device).float())

            def model_fn(x, t_continuous):
                t = t_continuous * self.N
                if mode == 'joint':
                    return self.joint_nnet(x, t)
                elif mode == 't2i':
                    return self.t2i_nnet(x, t, **kwargs)
                elif mode == 'i2t':
                    return self.i2t_nnet(x, t, **kwargs)
                elif mode == 'i':
                    return self.i_nnet(x, t)
                elif mode == 't':
                    return self.t_nnet(x, t)

            dpm_solver = DPM_Solver(model_fn,
                                    noise_schedule,
                                    predict_x0=True,
                                    thresholding=False)
            with torch.inference_mode(), torch.autocast(
                    device_type=self.device.type):
                x = dpm_solver.sample(_x_init,
                                      steps=self.config.sample.sample_steps,
                                      eps=1. / self.N,
                                      T=1.)
            if mode == 'joint':
                _z, _clip_img, _text = self.split_joint(x)
                return _z, _clip_img, _text
            elif mode in ['t2i', 'i']:
                _z, _clip_img = self.split(x)
                return _z, _clip_img
            elif mode in ['i2t', 't']:
                return x

        return sample_fn

    def to_pil(self, tensor: torch.Tensor) -> PIL.Image.Image:
        return PIL.Image.fromarray(
            tensor.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to(
                'cpu', torch.uint8).numpy())

    def run(self, mode: str, prompt: str, image_path: str, seed: int,
            num_steps: int,
            guidance_scale: float) -> tuple[PIL.Image.Image | None, str]:
        self.config.mode = mode
        self.config.prompt = prompt
        self.config.img = image_path
        self.config.seed = seed
        self.config.sample.sample_steps = num_steps
        self.config.sample.scale = guidance_scale
        self.config.n_samples = 1

        #set_seed(self.config.seed)
        if seed == -1:
            seed = random.randint(0, 1000000)
        torch.manual_seed(seed)

        contexts, img_contexts, clip_imgs = self.prepare_contexts()
        if self.use_caption_decoder:
            contexts_low_dim = self.caption_decoder.encode_prefix(contexts)
        else:
            contexts_low_dim = contexts
        z_img = self.autoencoder.sample(img_contexts)

        if self.config.mode in ['t2i', 't2i2t']:
            _n_samples = contexts_low_dim.size(0)
        elif self.config.mode in ['i2t', 'i2t2i']:
            _n_samples = img_contexts.size(0)
        else:
            _n_samples = self.config.n_samples
        sample_fn = self.get_sample_fn(_n_samples)

        if self.config.mode == 'joint':
            _z, _clip_img, _text = sample_fn(self.config.mode)
            samples = self.unpreprocess(self.decode(_z))
            samples = [self.to_pil(tensor) for tensor in samples]
            prompts = self.caption_decoder.generate_captions(_text)
            return samples[0], prompts[0]
        elif self.config.mode in ['t2i', 'i', 'i2t2i']:
            if self.config.mode == 't2i':
                # conditioned on the text embedding
                _z, _clip_img = sample_fn(self.config.mode,
                                          text=contexts_low_dim)
            elif self.config.mode == 'i':
                _z, _clip_img = sample_fn(self.config.mode)
            elif self.config.mode == 'i2t2i':
                # conditioned on the image embedding
                _text = sample_fn('i2t', z=z_img, clip_img=clip_imgs)
                _z, _clip_img = sample_fn('t2i', text=_text)
            samples = self.unpreprocess(self.decode(_z))
            samples = [self.to_pil(tensor) for tensor in samples]
            return samples[0], ''
        elif self.config.mode in ['i2t', 't', 't2i2t']:
            if self.config.mode == 'i2t':
                # conditioned on the image embedding
                _text = sample_fn(self.config.mode, z=z_img,
                                  clip_img=clip_imgs)
            elif self.config.mode == 't':
                _text = sample_fn(self.config.mode)
            elif self.config.mode == 't2i2t':
                _z, _clip_img = sample_fn('t2i', text=contexts_low_dim)
                _text = sample_fn('i2t', z=_z, clip_img=_clip_img)
            prompts = self.caption_decoder.generate_captions(_text)
            return None, prompts[0]
        else:
            raise ValueError
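

# Minimal usage sketch: illustrates the Model.run() signature only. The prompt,
# step count and guidance scale below are arbitrary example values, and running
# it assumes the `unidiffuser` submodule, downloaded weights and a CUDA device
# are available, as in the Space environment above.
if __name__ == '__main__':
    model = Model()
    # arguments: mode, prompt, image_path, seed, num_steps, guidance_scale
    image, caption = model.run('t2i', 'an elephant under the sea', '', 0, 50,
                               7.0)
    if image is not None:
        image.save('t2i_sample.png')
    print(caption)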