Spaces:
Running
on
Zero
Running
on
Zero
import os | |
import re | |
import time | |
from dataclasses import dataclass | |
from glob import iglob | |
from einops import rearrange | |
from PIL import ExifTags, Image | |
import torch | |
import gradio as gr | |
import numpy as np | |
from flux.sampling import prepare | |
from flux.util import (load_ae, load_clip, load_t5) | |
from models.kv_edit import Flux_kv_edit,Flux_kv_edit_inf | |
import spaces | |
from huggingface_hub import login | |
# Log in to the Hugging Face Hub with the token stored in the "Token"
# environment variable (e.g. a Space secret). huggingface_hub's login() falls
# back to an interactive prompt when no token is given, which cannot be
# answered on a headless server, so only call it when the variable is set.
_hf_token = os.getenv('Token')
if _hf_token:
    login(token=_hf_token)
else:
    print("Warning: environment variable 'Token' is not set; skipping Hugging Face Hub login.")
@dataclass
class SamplingOptions:
    """Settings for a single KV-Edit run, built by ``edit`` from the UI inputs.

    Declared as a dataclass so it can be constructed with keyword arguments
    (``SamplingOptions(source_prompt=..., width=..., ...)``). A plain class
    that only has annotated class attributes has no such ``__init__``, and
    that call raises ``TypeError``.
    """

    # Prompt describing the original image (prepared as the source input).
    source_prompt: str = ''
    # Prompt describing the desired result (prepared as the target input).
    target_prompt: str = ''
    # Image size in pixels; edit() passes the canvas size cropped to multiples of 16.
    width: int = 1366
    height: int = 768
    inversion_num_steps: int = 0
    denoise_num_steps: int = 0
    skip_step: int = 0
    inversion_guidance: float = 1.0
    denoise_guidance: float = 1.0
    # RNG seed; edit() replaces a UI value of -1 with a random seed before building this.
    seed: int = 42
    # Optional KV-Edit switches exposed as UI checkboxes.
    re_init: bool = False
    attn_mask: bool = False
def encode(init_image, torch_device):
    """Encode an image array into the autoencoder's latent space.

    Args:
        init_image: numpy array indexed (height, width, channel); the
            ``permute(2, 0, 1)`` below assumes channels-last. Presumably RGB
            values in [0, 255], since ``/ 127.5 - 1`` maps that range to
            [-1, 1] — TODO confirm with callers.
        torch_device: device the pixel tensor is moved to before encoding.

    Returns:
        The latent returned by the module-level ``ae.encode``, cast to bfloat16.
    """
    # (H, W, C) -> (C, H, W), then rescale [0, 255] -> [-1, 1].
    pixels = torch.from_numpy(init_image).permute(2, 0, 1).float() / 127.5 - 1
    # Add a leading batch dimension and move to the target device.
    pixels = pixels.unsqueeze(0).to(torch_device)
    with torch.no_grad():
        latent = ae.encode(pixels).to(torch.bfloat16)
    return latent
# ---------------------------------------------------------------------------
# Global model components, loaded once at import time and shared by edit().
# ---------------------------------------------------------------------------
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

name = 'flux-dev'
# The T5 text encoder's max_length is 256 for flux-schnell and 512 otherwise.
_t5_max_length = 256 if name == "flux-schnell" else 512

ae = load_ae(name, device)
t5 = load_t5(device, max_length=_t5_max_length)
clip = load_clip(device)
model = Flux_kv_edit(device=device, name=name)

# Run-time settings. `name` (above) and `output_dir` are read by edit(); the
# other flags are not referenced by the code visible in this file.
offload = False
is_schnell = False
feature_path = 'feature'
output_dir = 'result'
add_sampling_metadata = True
def _prepare_canvas(brush_canvas):
    """Crop the ImageEditor value so that both sides are multiples of 16.

    Returns ``(rgba_image, rgba_mask, height, width)``:
    - ``rgba_image`` is the cropped background.
    - ``rgba_mask`` is a cropped *copy* of the first brush layer. It is a copy
      so the caller can modify it without touching the array Gradio handed in.

    Both arrays are indexed (height, width, channel) and are assumed to be
    RGBA — TODO confirm against the component's ``image_mode``.

    Raises:
        gr.Error: if there is no uploaded image, no brush layer, or the image
            is smaller than 16 pixels along either side.
    """
    background = brush_canvas.get("background") if brush_canvas else None
    if background is None:
        raise gr.Error("Please upload an image to edit.")
    layers = brush_canvas.get("layers") or []
    if not layers:
        raise gr.Error("Please draw the area to edit with the brush tool.")

    full_height, full_width = background.shape[:2]
    # Round each side down to the nearest multiple of 16.
    height = full_height - full_height % 16
    width = full_width - full_width % 16
    if height == 0 or width == 0:
        raise gr.Error("The image must be at least 16x16 pixels.")

    rgba_image = background[:height, :width, :]
    rgba_mask = layers[0][:height, :width, :].copy()
    return rgba_image, rgba_mask, height, width


def _next_output_path(directory):
    """Return ``<directory>/img_<idx>.jpg`` for the next unused index.

    Creates *directory* if it does not exist. ``idx`` is one more than the
    largest index among existing ``img_<n>.jpg`` files, or 0 when there are none.
    """
    os.makedirs(directory, exist_ok=True)
    indices = []
    for existing in iglob(os.path.join(directory, "img_*.jpg")):
        match = re.search(r"img_([0-9]+)\.jpg$", existing)
        if match:
            indices.append(int(match.group(1)))
    next_idx = max(indices) + 1 if indices else 0
    return os.path.join(directory, f"img_{next_idx}.jpg")


# ZeroGPU only attaches a GPU inside functions decorated with spaces.GPU; outside
# ZeroGPU the decorator has no effect. `duration` is the per-call GPU budget in
# seconds — a generous guess for inversion + denoising, tune as needed.
@spaces.GPU(duration=120)
def edit(brush_canvas,
         source_prompt, target_prompt,
         inversion_num_steps, denoise_num_steps,
         skip_step,
         inversion_guidance, denoise_guidance, seed,
         re_init, attn_mask
         ):
    """Apply a KV-Edit to the uploaded image and return the result as a PIL image.

    Args:
        brush_canvas: value of the Gradio ImageEditor. It is a dict whose
            ``"background"`` is the uploaded image and whose first
            ``"layers"`` entry is the brush layer; fully opaque brush pixels
            mark the area to edit.
        source_prompt: text describing the original image.
        target_prompt: text describing the desired result.
        inversion_num_steps, denoise_num_steps, skip_step,
        inversion_guidance, denoise_guidance, re_init, attn_mask:
            forwarded unchanged into ``SamplingOptions`` for the model.
        seed: integer-like value (the UI sends text); -1 picks a random seed.

    Returns:
        The edited ``PIL.Image``. As a side effect, the image is saved as
        ``<output_dir>/img_<idx>.jpg``, together with a ``_mask.png`` preview
        of the brush mask composited over the input.

    Raises:
        gr.Error: if no image was uploaded, there is no brush layer, the image
            is smaller than 16x16, or the seed is not an integer.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    rgba_init_image, rgba_mask, height, width = _prepare_canvas(brush_canvas)
    init_image = rgba_init_image[:, :, :3]

    # alpha / 255 truncated to int: only fully opaque brush pixels (alpha == 255)
    # become 1; any partially transparent pixel becomes 0.
    mask = (rgba_mask[:, :, 3] / 255).astype(int)
    # Halve the brush opacity so the saved preview shows the image beneath it.
    rgba_mask[:, :, 3] = rgba_mask[:, :, 3] // 2
    masked_image = Image.alpha_composite(Image.fromarray(rgba_init_image, 'RGBA'), Image.fromarray(rgba_mask, 'RGBA'))
    # (H, W) -> (1, 1, H, W)
    mask = torch.from_numpy(mask).unsqueeze(0).unsqueeze(0).to(torch.bfloat16).to(device)
    init_image = encode(init_image, device).to(device)

    try:
        seed = int(seed)
    except (TypeError, ValueError):
        raise gr.Error("Seed must be an integer (-1 for random).") from None
    if seed == -1:
        seed = torch.randint(0, 2**32, (1,)).item()

    opts = SamplingOptions(
        source_prompt=source_prompt,
        target_prompt=target_prompt,
        width=width,
        height=height,
        inversion_num_steps=inversion_num_steps,
        denoise_num_steps=denoise_num_steps,
        skip_step=skip_step,
        inversion_guidance=inversion_guidance,
        denoise_guidance=denoise_guidance,
        seed=seed,
        re_init=re_init,
        attn_mask=attn_mask
    )
    torch.manual_seed(opts.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(opts.seed)
    t0 = time.perf_counter()

    # Inversion + denoising in latent space.
    with torch.no_grad():
        inp = prepare(t5, clip, init_image, prompt=opts.source_prompt)
        inp_target = prepare(t5, clip, init_image, prompt=opts.target_prompt)
        x = model(inp, inp_target, mask, opts)

    # Decode back to pixel space. Autocast targets the same device the rest of
    # this function uses, so the CPU fallback stays consistent.
    with torch.no_grad(), torch.autocast(device_type=torch.device(device).type, dtype=torch.bfloat16):
        x = ae.decode(x)

    # Convert to an (H, W, C) float tensor on the CPU in [-1, 1].
    x = x.clamp(-1, 1)
    x = x.float().cpu()
    x = rearrange(x[0], "c h w -> h w c")
    if torch.cuda.is_available():
        torch.cuda.synchronize()

    fn = _next_output_path(output_dir)
    img = Image.fromarray((127.5 * (x + 1.0)).byte().numpy())
    exif_data = Image.Exif()
    exif_data[ExifTags.Base.Software] = "AI generated;txt2img;flux"
    exif_data[ExifTags.Base.Make] = "Black Forest Labs"
    exif_data[ExifTags.Base.Model] = name
    exif_data[ExifTags.Base.ImageDescription] = target_prompt
    img.save(fn, exif=exif_data, quality=95, subsampling=0)
    masked_image.save(fn.replace(".jpg", "_mask.png"), format='PNG')
    t1 = time.perf_counter()
    print(f"Done in {t1 - t0:.1f}s. Saving {fn}")
    print("End Edit")
    return img
def create_demo(model_name: str):
    """Build the Gradio Blocks UI for KV-Edit and wire the edit button to ``edit``.

    Args:
        model_name: model identifier. When it is ``"flux-schnell"``, the two
            guidance sliders are made non-interactive.

    Returns:
        The ``gr.Blocks`` app (not yet launched).
    """
    is_schnell = model_name == "flux-schnell"
    title = r"""
<h1 align="center">🎨 KV-Edit: Training-Free Image Editing for Precise Background Preservation</h1>
"""
    description = r"""
<b>Official 🤗 Gradio demo</b> for <a href='https://github.com/Xilluill/KV-Edit' target='_blank'><b>KV-Edit: Training-Free Image Editing for Precise Background Preservation</b></a>.<br>
💫💫 <b>Here is editing steps:</b> (We highly recommend you run our code locally!😘 Only one inversion before multiple editing, very productive!) <br>
1️⃣ Upload your image that needs to be edited (The resolution must be less than 1360*768 because of memory.🙂) <br>
2️⃣ Fill in your source prompt and use the brush tool to draw your mask area. <br>
3️⃣ Fill in your target prompt, then adjust the hyperparameters. <br>
4️⃣ Click the "Edit" button to generate your edited image! <br>
🔔🔔 [<b>Important</b>] We suggest trying "re_init" and "attn_mask" only when the result is too similar to the original content (e.g. removing objects).<br>
"""
    article = r"""
If our work is helpful, please help to ⭐ the <a href='https://github.com/Xilluill/KV-Edit' target='_blank'>Github Repo</a>. Thanks!
"""
    with gr.Blocks() as demo:
        gr.HTML(title)
        gr.Markdown(description)
        with gr.Row():
            with gr.Column():
                source_prompt = gr.Textbox(label="Source Prompt", value='')
                inversion_num_steps = gr.Slider(1, 50, 28, step=1, label="Number of inversion steps")
                target_prompt = gr.Textbox(label="Target Prompt", value='')
                denoise_num_steps = gr.Slider(1, 50, 28, step=1, label="Number of denoise steps")
                # `sources` must be an iterable of source names; ('upload') is
                # just the string 'upload', so use a one-element tuple.
                brush_canvas = gr.ImageEditor(label="Brush Canvas",
                                              sources=("upload",),
                                              brush=gr.Brush(colors=["#ff0000"], color_mode='fixed'),
                                              interactive=True,
                                              transforms=[],
                                              container=True,
                                              format='png', scale=1)
                edit_btn = gr.Button("edit")
            with gr.Column():
                with gr.Accordion("Advanced Options", open=True):
                    skip_step = gr.Slider(0, 30, 4, step=1, label="Number of skip steps")
                    inversion_guidance = gr.Slider(1.0, 10.0, 1.5, step=0.1, label="inversion Guidance", interactive=not is_schnell)
                    denoise_guidance = gr.Slider(1.0, 10.0, 5.5, step=0.1, label="denoise Guidance", interactive=not is_schnell)
                    seed = gr.Textbox('0', label="Seed (-1 for random)", visible=True)
                    with gr.Row():
                        re_init = gr.Checkbox(label="re_init", value=False)
                        attn_mask = gr.Checkbox(label="attn_mask", value=False)
                output_image = gr.Image(label="Generated Image")
        gr.Markdown(article)
        # Input order must match edit()'s positional parameters.
        edit_btn.click(
            fn=edit,
            inputs=[brush_canvas,
                    source_prompt, target_prompt,
                    inversion_num_steps, denoise_num_steps,
                    skip_step,
                    inversion_guidance,
                    denoise_guidance, seed,
                    re_init, attn_mask
                    ],
            outputs=[output_image]
        )
    return demo
# Build the UI at import time so `demo` remains a module-level attribute for
# tools that import this file, but only start the server when run as a script.
demo = create_demo("flux-dev")

if __name__ == "__main__":
    demo.launch()