from dataclasses import dataclass

from einops import rearrange, repeat
import torch
import torch.nn.functional as F
from torch import Tensor
from typing import List

from flux.sampling import get_schedule, unpack, denoise_kv, denoise_kv_inf
from flux.util import load_flow_model
from flux.model import Flux_kv


@dataclass
class SamplingOptions:
    source_prompt: str = ''
    target_prompt: str = ''
    # prompt: str
    width: int = 1366
    height: int = 768
    inversion_num_steps: int = 0
    denoise_num_steps: int = 0
    skip_step: int = 0
    inversion_guidance: float = 1.0
    denoise_guidance: float = 1.0
    seed: int = 42
    re_init: bool = False
    attn_mask: bool = False


class only_Flux(torch.nn.Module):  # only holds the shared model initialization
    def __init__(self, device, name='flux-dev'):
        self.device = device
        self.name = name
        super().__init__()
        self.model = load_flow_model(self.name, device=self.device, flux_cls=Flux_kv)
    def create_attention_mask(self, seq_len, mask_indices, text_len=512, device='cuda'):
        """
        Build a custom attention mask that keeps the masked image region and the
        background from attending to each other.

        Args:
            seq_len (int): Total sequence length (text tokens + image tokens).
            mask_indices (List[int]): Indices of the masked region within the image tokens.
            text_len (int): Number of text tokens; defaults to 512.
            device (str): Device to build the mask on, e.g. 'cuda' or 'cpu'.

        Returns:
            torch.Tensor: Attention mask of shape (1, seq_len, seq_len).
        """
        # Initialize the mask to all False (no attention allowed)
        attention_mask = torch.zeros(seq_len, seq_len, dtype=torch.bool, device=device)

        # Text token indices
        text_indices = torch.arange(0, text_len, device=device)

        # Masked-region token indices, offset into the full sequence by the text length
        mask_token_indices = torch.tensor([idx + text_len for idx in mask_indices], device=device)

        # Background token indices (image tokens not in the masked region)
        all_indices = torch.arange(text_len, seq_len, device=device)
        background_token_indices = torch.tensor(
            [idx for idx in all_indices if idx not in mask_token_indices], device=device)

        # Text queries attend to all keys
        attention_mask[text_indices.unsqueeze(1).expand(-1, seq_len)] = True
        attention_mask[text_indices.unsqueeze(1), text_indices] = True  # attend to text
        attention_mask[text_indices.unsqueeze(1), background_token_indices] = True  # attend to background

        # Masked-region queries attend to text and the masked region only
        # attention_mask[mask_token_indices.unsqueeze(1), background_token_indices] = True  # attend to background
        attention_mask[mask_token_indices.unsqueeze(1), text_indices] = True  # attend to text
        attention_mask[mask_token_indices.unsqueeze(1), mask_token_indices] = True  # attend to masked region

        # Background queries attend to text and the background only
        # attention_mask[background_token_indices.unsqueeze(1).expand(-1, seq_len), :] = False
        # attention_mask[background_token_indices.unsqueeze(1), mask_token_indices] = True  # attend to masked region
        attention_mask[background_token_indices.unsqueeze(1), text_indices] = True  # attend to text
        attention_mask[background_token_indices.unsqueeze(1), background_token_indices] = True  # attend to background

        return attention_mask.unsqueeze(0)
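
    # Usage sketch (hypothetical sizes, not from the original file): with a
    # 512-token prompt and an image of 4 tokens where tokens 1 and 2 are masked,
    #     m = self.create_attention_mask(seq_len=516, mask_indices=[1, 2])
    # yields a (1, 516, 516) bool mask in which masked-region queries see only
    # text + masked tokens and background queries see only text + background
    # tokens, so the two image regions never exchange information.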


class Flux_kv_edit_inf(only_Flux):
    def __init__(self, device, name):
        super().__init__(device, name)

    def forward(self, inp, inp_target, mask: Tensor, opts):
        # Derive token-sequence indices from the spatial mask
        info = {}
        info['feature'] = {}
        bs, L, d = inp["img"].shape
        h = opts.height // 8
        w = opts.width // 8

        mask = F.interpolate(mask, size=(h, w), mode='bilinear', align_corners=False)
        mask[mask > 0] = 1

        mask = repeat(mask, 'b c h w -> b (repeat c) h w', repeat=16)
        # mask = F.max_pool2d(mask, kernel_size=3, stride=1, padding=1)
        # mask = mask.flatten().to(self.device[1])
        mask = rearrange(mask, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
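        # Each sequence token packs a 2x2 latent patch (Flux's latent packing),
        # so the (b, 16, h, w) mask becomes (b, (h/2)*(w/2), 64): one row per
        # image token, 64 = 16 channels x 2 x 2 patch values.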
        info['mask'] = mask
        bool_mask = (mask.sum(dim=2) > 0.5)
        # Fancy indexing (an integer tensor indexing a tensor); these indices are
        # relative to the image tokens, so add 512 to get full-sequence positions
        info['mask_indices'] = torch.nonzero(bool_mask)[:, 1]

        # Inversion is handled separately. Apply the attention mask only if at
        # least one token is unmasked; if every token is masked, pass None
        if opts.attn_mask and (~bool_mask).any():
            attention_mask = self.create_attention_mask(L + 512, info['mask_indices'], device=self.device)
        else:
            attention_mask = None
        info['attention_mask'] = attention_mask

        denoise_timesteps = get_schedule(opts.denoise_num_steps, inp["img"].shape[1], shift=(self.name != "flux-schnell"))
        # denoise_timesteps = get_schedule(opts.denoise_num_steps, inp_target["img"].shape[1], shift=False)
        denoise_timesteps = denoise_timesteps[opts.skip_step:]

        z0 = inp["img"]
        with torch.no_grad():
            info['inject'] = True
            z_fe, info = denoise_kv_inf(self.model, img=inp["img"], img_ids=inp['img_ids'],
                                        source_txt=inp['txt'], source_txt_ids=inp['txt_ids'], source_vec=inp['vec'],
                                        target_txt=inp_target['txt'], target_txt_ids=inp_target['txt_ids'], target_vec=inp_target['vec'],
                                        timesteps=denoise_timesteps, source_guidance=opts.inversion_guidance, target_guidance=opts.denoise_guidance,
                                        info=info)
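        # Note: this variant feeds the source and target prompts through
        # denoise_kv_inf together, which appears to interleave inversion and
        # denoising at each timestep rather than storing a full inversion
        # trajectory first (compare Flux_kv_edit.inverse/denoise below).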
        mask_indices = info['mask_indices']  # indices in image-sequence coordinates
        # z_fe holds only the tokens gathered by mask_indices; scatter them back
        z0[:, mask_indices, ...] = z_fe

        # unpack the token sequence back to the latent image layout
        z0 = unpack(z0.float(), opts.height, opts.width)
        del info
        return z0


class Flux_kv_edit(only_Flux):
    def __init__(self, device, name):
        super().__init__(device, name)

    def forward(self, inp, inp_target, mask: Tensor, opts):
        z0, zt, info = self.inverse(inp, mask, opts)
        z0 = self.denoise(z0, zt, inp_target, mask, opts, info)
        return z0
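
    # The edit runs in two stages: `inverse` noises the source latents while
    # recording per-step features in `info`, then `denoise` regenerates only
    # the masked tokens under the target prompt and scatters them back into
    # the source latents.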

    def inverse(self, inp, mask, opts):
        info = {}
        info['feature'] = {}
        bs, L, d = inp["img"].shape
        h = opts.height // 8
        w = opts.width // 8

        # mask = F.interpolate(mask, size=(h,w), mode='nearest')
        if opts.attn_mask:
            mask = F.interpolate(mask, size=(h, w), mode='bilinear', align_corners=False)
            mask[mask > 0] = 1

            mask = repeat(mask, 'b c h w -> b (repeat c) h w', repeat=16)
            # mask = F.max_pool2d(mask, kernel_size=3, stride=1, padding=1)
            # mask = mask.flatten().to(self.device[1])
            mask = rearrange(mask, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
            bool_mask = (mask.sum(dim=2) > 0.5)
            # Fancy indexing; image-token based, add 512 for sequence positions
            mask_indices = torch.nonzero(bool_mask)[:, 1]

            # The attention mask needs both regions present: it is meaningless
            # if the mask covers nothing or everything
            assert not (~bool_mask).all(), "mask is all false"
            assert not bool_mask.all(), "mask is all true"

            attention_mask = self.create_attention_mask(L + 512, mask_indices, device=mask.device)
            info['attention_mask'] = attention_mask

        denoise_timesteps = get_schedule(opts.denoise_num_steps, inp["img"].shape[1], shift=(self.name != "flux-schnell"))
        denoise_timesteps = denoise_timesteps[opts.skip_step:]

        # Noising (inversion) pass over the source latents
        z0 = inp["img"].clone()
        info['inverse'] = True
        zt, info = denoise_kv(self.model, **inp, timesteps=denoise_timesteps, guidance=opts.inversion_guidance, inverse=True, info=info)
        return z0, zt, info

    def denoise(self, z0, zt, inp_target, mask: Tensor, opts, info):
        h = opts.height // 8
        w = opts.width // 8

        mask = F.interpolate(mask, size=(h, w), mode='bilinear', align_corners=False)
        mask[mask > 0] = 1

        mask = repeat(mask, 'b c h w -> b (repeat c) h w', repeat=16)
        mask = rearrange(mask, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
        info['mask'] = mask
        bool_mask = (mask.sum(dim=2) > 0.5)
        # Fancy indexing; image-token based, add 512 for sequence positions
        info['mask_indices'] = torch.nonzero(bool_mask)[:, 1]

        denoise_timesteps = get_schedule(opts.denoise_num_steps, inp_target["img"].shape[1], shift=(self.name != "flux-schnell"))
        denoise_timesteps = denoise_timesteps[opts.skip_step:]

        # Reconstruction does not need every token: gather only the masked ones
        mask_indices = info['mask_indices']  # indices in image-sequence coordinates
        if opts.re_init:
            noise = torch.randn_like(zt)
            t = denoise_timesteps[0]
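            # Re-noise the clean latent directly at the first kept timestep,
            # z_t = (1 - t) * z0 + t * noise (the flow-matching forward
            # process), instead of starting from the inverted zt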
            zt_noise = z0 * (1 - t) + noise * t
            inp_target["img"] = zt_noise[:, mask_indices, ...]
        else:
            inp_target["img"] = zt[:, mask_indices, ...]

        info['inverse'] = False
        x, _ = denoise_kv(self.model, **inp_target, timesteps=denoise_timesteps, guidance=opts.denoise_guidance, inverse=False, info=info)

        # x holds only the tokens gathered by mask_indices; blend them back into
        # the source latents, weighted by the per-patch mask values
        z0[:, mask_indices, ...] = z0[:, mask_indices, ...] * (1 - info['mask'][:, mask_indices, ...]) + x * info['mask'][:, mask_indices, ...]
        # x = inp['img'].clone()

        # unpack the token sequence back to the latent image layout
        z0 = unpack(z0.float(), opts.height, opts.width)
        del info
        return z0
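

# A minimal usage sketch (not part of the original file). It assumes the
# surrounding repo's text/image encoders have already produced `inp` and
# `inp_target` dicts (img, img_ids, txt, txt_ids, vec) in the format that
# denoise_kv expects, and that `mask` is a (1, 1, H, W) binary tensor on the
# same device. All literal values below are illustrative.
#
# if __name__ == "__main__":
#     opts = SamplingOptions(
#         source_prompt="a cat on the sofa",
#         target_prompt="a dog on the sofa",
#         width=1024, height=1024,
#         inversion_num_steps=28, denoise_num_steps=28,
#         skip_step=4, attn_mask=True,
#     )
#     editor = Flux_kv_edit(device="cuda", name="flux-dev")
#     edited = editor(inp, inp_target, mask, opts)  # latents, ready for VAE decoding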