Spaces:
Runtime error
Runtime error
| import math | |
| from typing import Callable | |
| import torch | |
| from einops import rearrange, repeat | |
| from torch import Tensor | |
| from .model import Flux,Flux_kv | |
| from .modules.conditioner import HFEmbedder | |
| from tqdm import tqdm | |
| from tqdm.contrib import tzip | |
def get_noise(
    num_samples: int,
    height: int,
    width: int,
    device: torch.device,
    dtype: torch.dtype,
    seed: int,
):
    """Draw seeded standard-normal noise for a batch of 16-channel latents.

    Args:
        num_samples: Size of the leading (batch) dimension.
        height: Image height; the latent height is ``2 * ceil(height / 16)``.
        width: Image width; the latent width is ``2 * ceil(width / 16)``.
        device: Device for both the random generator and the result.
        dtype: Dtype of the returned tensor.
        seed: Seed for a generator created solely for this call.

    Returns:
        Tensor of shape ``(num_samples, 16, latent_h, latent_w)``.
    """
    # Rounding up and then doubling keeps both spatial sizes even, which
    # allows the later 2x2 packing.
    latent_h = 2 * math.ceil(height / 16)
    latent_w = 2 * math.ceil(width / 16)

    # A dedicated generator keeps the draw reproducible without touching the
    # global RNG state.
    rng = torch.Generator(device=device).manual_seed(seed)
    shape = (num_samples, 16, latent_h, latent_w)
    return torch.randn(shape, device=device, dtype=dtype, generator=rng)
def prepare(t5: HFEmbedder, clip: HFEmbedder, img: Tensor, prompt: str | list[str]) -> dict[str, Tensor]:
    """Pack a latent image into 2x2-patch tokens and embed the prompt(s).

    Args:
        t5: Embedder called on the prompt list to produce ``txt``.
        clip: Embedder called on the prompt list to produce ``vec``.
        img: 4-D latent ``(b, c, h, w)``; ``h`` and ``w`` must be even for
            the 2x2 packing.
        prompt: A single prompt or a list of prompts. When ``img`` has a
            batch of one and a list is given, the list length becomes the
            batch size.

    Returns:
        Dict with ``img`` (packed tokens), ``img_ids`` (per-token position
        ids: channel 1 is the row index, channel 2 the column index,
        channel 0 stays zero), ``txt``, ``txt_ids`` (all zeros) and ``vec``,
        all moved to the device of the packed image.
    """
    batch, _, height, width = img.shape
    if not isinstance(prompt, str) and batch == 1:
        batch = len(prompt)

    tokens = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
    if tokens.shape[0] == 1 and batch > 1:
        tokens = repeat(tokens, "1 ... -> bs ...", bs=batch)

    # One position id per 2x2 patch.
    rows, cols = height // 2, width // 2
    grid = torch.zeros(rows, cols, 3)
    grid[..., 1] = torch.arange(rows)[:, None]
    grid[..., 2] = torch.arange(cols)[None, :]
    grid = repeat(grid, "h w c -> b (h w) c", b=batch)

    prompts = [prompt] if isinstance(prompt, str) else prompt

    txt = t5(prompts)
    if txt.shape[0] == 1 and batch > 1:
        txt = repeat(txt, "1 ... -> bs ...", bs=batch)
    txt_ids = torch.zeros(batch, txt.shape[1], 3)

    vec = clip(prompts)
    if vec.shape[0] == 1 and batch > 1:
        vec = repeat(vec, "1 ... -> bs ...", bs=batch)

    target = tokens.device
    return {
        "img": tokens,
        "img_ids": grid.to(target),
        "txt": txt.to(target),
        "txt_ids": txt_ids.to(target),
        "vec": vec.to(target),
    }
def _embed_prompt(
    t5: HFEmbedder, clip: HFEmbedder, prompt: str | list[str], bs: int
) -> tuple[Tensor, Tensor, Tensor]:
    """Embed one prompt (or prompt list) with T5 and CLIP, expanded to batch size ``bs``."""
    if isinstance(prompt, str):
        prompt = [prompt]
    txt = t5(prompt)
    if txt.shape[0] == 1 and bs > 1:
        txt = repeat(txt, "1 ... -> bs ...", bs=bs)
    txt_ids = torch.zeros(bs, txt.shape[1], 3)
    vec = clip(prompt)
    if vec.shape[0] == 1 and bs > 1:
        vec = repeat(vec, "1 ... -> bs ...", bs=bs)
    return txt, txt_ids, vec


def prepare_flowedit(
    t5: HFEmbedder,
    clip: HFEmbedder,
    img: Tensor,
    source_prompt: str | list[str],
    target_prompt: str | list[str],
) -> dict[str, Tensor]:
    """Pack a latent image and embed both the source and the target prompt.

    Args:
        t5: Embedder called on each prompt list to produce the ``*_txt`` entries.
        clip: Embedder called on each prompt list to produce the ``*_vec`` entries.
        img: 4-D latent ``(b, c, h, w)``; ``h`` and ``w`` must be even for
            the 2x2 packing.
        source_prompt: Prompt(s) for the source image. When ``img`` has a
            batch of one and a list is given, its length becomes the batch
            size.
        target_prompt: Prompt(s) for the edited image.

    Returns:
        Dict with ``img`` (packed tokens), ``img_ids`` (per-token position
        ids: channel 1 is the row index, channel 2 the column index,
        channel 0 stays zero) and, for each of ``source`` / ``target``, the
        entries ``<name>_txt``, ``<name>_txt_ids`` (all zeros) and
        ``<name>_vec``; everything is moved to the device of the packed image.
    """
    bs, _, h, w = img.shape
    if bs == 1 and not isinstance(source_prompt, str):
        bs = len(source_prompt)

    img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
    if img.shape[0] == 1 and bs > 1:
        img = repeat(img, "1 ... -> bs ...", bs=bs)

    # One position id per 2x2 patch.
    img_ids = torch.zeros(h // 2, w // 2, 3)
    img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
    img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
    img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)

    # The source CLIP vector must come from the source prompt. It was
    # previously computed from target_prompt, which made source_vec a copy
    # of target_vec.
    source_txt, source_txt_ids, source_vec = _embed_prompt(t5, clip, source_prompt, bs)
    target_txt, target_txt_ids, target_vec = _embed_prompt(t5, clip, target_prompt, bs)

    device = img.device
    return {
        "img": img,
        "img_ids": img_ids.to(device),
        "source_txt": source_txt.to(device),
        "source_txt_ids": source_txt_ids.to(device),
        "source_vec": source_vec.to(device),
        "target_txt": target_txt.to(device),
        "target_txt_ids": target_txt_ids.to(device),
        "target_vec": target_vec.to(device),
    }
def time_shift(mu: float, sigma: float, t: Tensor):
    """Warp timestep(s) ``t`` as ``exp(mu) / (exp(mu) + (1 / t - 1) ** sigma)``.

    Args:
        mu: Shift parameter; enters the formula as ``exp(mu)``.
        sigma: Exponent applied to ``1 / t - 1``.
        t: Timestep value(s).

    Returns:
        The warped timestep(s), same shape as ``t``.
    """
    scale = math.exp(mu)
    odds = (1 / t - 1) ** sigma
    return scale / (scale + odds)
def get_lin_function(
    x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
) -> Callable[[float], float]:
    """Build the straight line through ``(x1, y1)`` and ``(x2, y2)``.

    Returns:
        A callable mapping ``x`` to ``slope * x + intercept``.
    """
    slope = (y2 - y1) / (x2 - x1)
    intercept = y1 - slope * x1

    def line(x):
        return slope * x + intercept

    return line
def get_schedule(
    num_steps: int,
    image_seq_len: int,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
    shift: bool = True,
) -> list[float]:
    """Return ``num_steps + 1`` timesteps running from 1 down to 0.

    Args:
        num_steps: Number of sampling steps; one extra entry is added so the
            schedule ends at zero.
        image_seq_len: Image token count, used to pick the shift strength.
        base_shift: Line value at the default ``x1`` of ``get_lin_function``.
        max_shift: Line value at the default ``x2`` of ``get_lin_function``.
        shift: When true, warp the linear schedule with ``time_shift`` to
            favor high timesteps for higher signal images.

    Returns:
        The schedule as a plain list of floats.
    """
    linear = torch.linspace(1, 0, num_steps + 1)
    if not shift:
        return linear.tolist()

    # mu is estimated by linear interpolation between two points.
    mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
    return time_shift(mu, 1.0, linear).tolist()
def denoise(
    model: Flux,
    # model input
    img: Tensor,
    img_ids: Tensor,
    txt: Tensor,
    txt_ids: Tensor,
    vec: Tensor,
    # sampling parameters
    timesteps: list[float],
    guidance: float = 4.0,
):
    """Integrate the model's prediction along ``timesteps`` with Euler steps.

    For each consecutive pair ``(t_curr, t_prev)`` the model is called at
    ``t_curr`` and the image is updated by ``(t_prev - t_curr) * pred``.

    Args:
        model: Callable invoked with keyword arguments ``img``, ``img_ids``,
            ``txt``, ``txt_ids``, ``y``, ``timesteps`` and ``guidance``.
        img: Starting latent tokens.
        img_ids: Forwarded to the model unchanged.
        txt: Forwarded to the model unchanged.
        txt_ids: Forwarded to the model unchanged.
        vec: Forwarded to the model as ``y``.
        timesteps: Schedule to walk through, in order.
        guidance: Scalar broadcast to one value per batch element.

    Returns:
        The latent after the last step.
    """
    # this is ignored for schnell
    guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)

    # Conditioning that does not change between steps.
    static_inputs = {
        "img_ids": img_ids,
        "txt": txt,
        "txt_ids": txt_ids,
        "y": vec,
        "guidance": guidance_vec,
    }

    for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]):
        t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
        pred = model(img=img, timesteps=t_vec, **static_inputs)
        step = t_prev - t_curr
        img = img + step * pred

    return img
def unpack(x: Tensor, height: int, width: int) -> Tensor:
    """Undo the 2x2 patch packing, turning tokens back into a 4-D latent.

    Args:
        x: Packed tokens laid out as ``b (h w) (c ph pw)``.
        height: Image height; the token grid has ``ceil(height / 16)`` rows.
        width: Image width; the token grid has ``ceil(width / 16)`` columns.

    Returns:
        Tensor laid out as ``b c (h ph) (w pw)`` with ``ph = pw = 2``.
    """
    grid_h = math.ceil(height / 16)
    grid_w = math.ceil(width / 16)
    pattern = "b (h w) (c ph pw) -> b c (h ph) (w pw)"
    return rearrange(x, pattern, h=grid_h, w=grid_w, ph=2, pw=2)
def denoise_kv(
    model: Flux_kv,
    # model input
    img: Tensor,
    img_ids: Tensor,
    txt: Tensor,
    txt_ids: Tensor,
    vec: Tensor,
    # sampling parameters
    timesteps: list[float],
    inverse,
    info,
    guidance: float = 4.0
):
    """Euler sampling loop that stores or re-uses per-timestep latents.

    With ``inverse`` truthy the schedule is reversed and the current latent
    is saved on the CPU under ``info['feature']['<t>_img']`` before each
    model call. Otherwise the latent saved for the same ``t`` is loaded and
    blended with the running latent through ``info['mask']`` (restricted to
    ``info['mask_indices']``) before each model call.

    Args:
        model: Callable invoked with ``img``, ``img_ids``, ``txt``,
            ``txt_ids``, ``y``, ``timesteps``, ``guidance`` and ``info``.
        img: Starting latent tokens.
        img_ids: Forwarded to the model unchanged.
        txt: Forwarded to the model unchanged.
        txt_ids: Forwarded to the model unchanged.
        vec: Forwarded to the model as ``y``.
        timesteps: Schedule; walked in reverse order when ``inverse`` is set.
        inverse: Selects the storing (truthy) or re-using (falsy) pass.
        info: Mutable state dict. Reads ``feature`` and, on the re-using
            pass, ``mask`` and ``mask_indices``; writes ``t`` and entries of
            ``feature``.
        guidance: Scalar broadcast to one value per batch element.

    Returns:
        Tuple ``(img, info)`` with the final latent and the same ``info`` dict.
    """
    if inverse:
        timesteps = timesteps[::-1]

    guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)

    for t_curr, t_prev in tzip(timesteps[:-1], timesteps[1:]):
        t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)

        # Both passes key the cache on the same timestep value: the storing
        # pass uses the step it moves to, the re-using pass the step it is at.
        info['t'] = t_prev if inverse else t_curr
        img_name = str(info['t']) + '_' + 'img'

        if inverse:
            info['feature'][img_name] = img.cpu()
        else:
            source_img = info['feature'][img_name].to(img.device)
            selected_mask = info['mask'][:, info['mask_indices'], ...]
            selected_source = source_img[:, info['mask_indices'], ...]
            # Stored latent where the mask is 0, running latent where it is 1.
            img = selected_source * (1 - selected_mask) + img * selected_mask

        pred = model(
            img=img,
            img_ids=img_ids,
            txt=txt,
            txt_ids=txt_ids,
            y=vec,
            timesteps=t_vec,
            guidance=guidance_vec,
            info=info,
        )
        img = img + (t_prev - t_curr) * pred

    return img, info
def denoise_kv_inf(
    model: Flux_kv,
    # model input
    img: Tensor,
    img_ids: Tensor,
    source_txt: Tensor,
    source_txt_ids: Tensor,
    source_vec: Tensor,
    target_txt: Tensor,
    target_txt_ids: Tensor,
    target_vec: Tensor,
    # sampling parameters
    timesteps: list[float],
    target_guidance: float = 4.0,
    source_guidance: float = 4.0,
    info: dict | None = None,
):
    """Inversion-free editing loop driven by the target/source prediction gap.

    At each step the source image is noised to ``t_curr``, the model is
    evaluated once with the source conditioning (``info['inverse'] = True``)
    and once with the target conditioning (``info['inverse'] = False``), and
    the edited latent ``z_fe`` is advanced by the masked difference of the
    two predictions.

    Args:
        model: Callable invoked with ``img``, ``img_ids``, ``txt``,
            ``txt_ids``, ``y``, ``timesteps``, ``guidance`` and ``info``.
        img: Source latent tokens (e.g. shape ``[1, 4080, 64]``).
        img_ids: Forwarded to both model calls unchanged.
        source_txt: Conditioning for the source pass.
        source_txt_ids: Conditioning for the source pass.
        source_vec: Forwarded to the source pass as ``y``.
        target_txt: Conditioning for the target pass.
        target_txt_ids: Conditioning for the target pass.
        target_vec: Forwarded to the target pass as ``y``.
        timesteps: Schedule walked in the given order (high to low).
        target_guidance: Guidance scalar for the target pass.
        source_guidance: Guidance scalar for the source pass.
        info: Mutable state dict. Must provide ``mask_indices`` and ``mask``;
            ``t``, ``inverse`` and ``feature`` are overwritten here.

    Returns:
        Tuple ``(z_fe, info)``: the edited latent restricted to
        ``mask_indices`` and the ``info`` dict.

    Raises:
        KeyError: If ``info`` lacks ``mask_indices`` (including when it is
            omitted) or ``mask``.
    """
    if info is None:
        # Create a fresh dict per call instead of sharing one mutable default
        # across calls. The required keys are still looked up below, so
        # omitting ``info`` fails with KeyError exactly as before.
        info = {}

    target_guidance_vec = torch.full((img.shape[0],), target_guidance, device=img.device, dtype=img.dtype)
    source_guidance_vec = torch.full((img.shape[0],), source_guidance, device=img.device, dtype=img.dtype)

    mask_indices = info['mask_indices']
    init_img = img.clone()
    z_fe = img[:, mask_indices, ...]

    # The original code drew one noise tensor per timestep, each from a new
    # generator seeded with 0, so all of them were identical. Drawing it once
    # gives the same values without holding len(timesteps) copies.
    # NOTE(review): if independent noise per step was intended, the seed has
    # to vary per step -- confirm against the reference method.
    noise = torch.randn(
        init_img.size(),
        dtype=init_img.dtype,
        layout=init_img.layout,
        device=init_img.device,
        generator=torch.Generator(device=init_img.device).manual_seed(0),
    )

    for t_curr, t_prev in tzip(timesteps[:-1], timesteps[1:]):  # high to low
        info['t'] = 'inf'
        t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)

        # Noise the source to level t_curr, then carry the accumulated edit
        # over to the masked tokens.
        z_src = (1 - t_curr) * init_img + t_curr * noise
        z_tar = z_src[:, mask_indices, ...] - init_img[:, mask_indices, ...] + z_fe

        info['inverse'] = True
        info['feature'] = {}  # clear the cached KV features
        v_src = model(
            img=z_src,
            img_ids=img_ids,
            txt=source_txt,
            txt_ids=source_txt_ids,
            y=source_vec,
            timesteps=t_vec,
            guidance=source_guidance_vec,
            info=info,
        )

        info['inverse'] = False
        v_tar = model(
            img=z_tar,
            img_ids=img_ids,
            txt=target_txt,
            txt_ids=target_txt_ids,
            y=target_vec,
            timesteps=t_vec,
            guidance=target_guidance_vec,
            info=info,
        )

        v_fe = v_tar - v_src[:, mask_indices, ...]
        z_fe = z_fe + (t_prev - t_curr) * v_fe * info['mask'][:, mask_indices, ...]

    return z_fe, info