|
<!--Copyright 2023 The HuggingFace Team. All rights reserved. |
|
|
|
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with |
|
the License. You may obtain a copy of the License at |
|
|
|
http://www.apache.org/licenses/LICENSE-2.0 |
|
|
|
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on |
|
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the |
|
specific language governing permissions and limitations under the License. |
|
--> |
|
|
|
# DiffEdit |
|
|
|
[[open-in-colab]] |
|
|
|
Image editing typically requires providing a mask of the area to be edited. DiffEdit automatically generates the mask for you based on a text query, making it easier overall to create a mask without image editing software. The DiffEdit algorithm works in three steps (sketched in code after the list):
|
|
|
1. the diffusion model denoises an image conditioned on some query text and reference text, which produces different noise estimates for different areas of the image; the difference is used to infer a mask that identifies which area of the image needs to be changed to match the query text
|
2. the input image is encoded into latent space with DDIM
|
3. the latents are decoded with the diffusion model conditioned on the text query, using the mask as a guide so that pixels outside the mask remain the same as in the input image
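
In code, these three steps correspond to three pipeline calls. The sketch below is only a preview of the workflow covered step by step in the rest of this guide; it assumes a loaded `pipeline`, a `raw_image`, and the source/target prompts, all of which are set up in the sections that follow:

```py
# Preview of the full DiffEdit workflow (each call is explained below):
mask_image = pipeline.generate_mask(       # step 1: infer the edit mask
    image=raw_image, source_prompt=source_prompt, target_prompt=target_prompt
)
inv_latents = pipeline.invert(             # step 2: invert the image into latents with DDIM
    prompt=source_prompt, image=raw_image
).latents
output_image = pipeline(                   # step 3: decode the latents, guided by the mask
    prompt=target_prompt, mask_image=mask_image,
    image_latents=inv_latents, negative_prompt=source_prompt,
).images[0]
```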
|
|
|
This guide will show you how to use DiffEdit to edit images without manually creating a mask.
|
|
|
Before you begin, make sure you have the following libraries installed:
|
|
|
```py |
|
# uncomment to install the necessary libraries in Colab
|
#!pip install -q diffusers transformers accelerate |
|
``` |
|
|
|
The [`StableDiffusionDiffEditPipeline`] requires an image mask and a set of partially inverted latents. The image mask is generated from the [`~StableDiffusionDiffEditPipeline.generate_mask`] function, which takes two parameters, `source_prompt` and `target_prompt`. These parameters determine what to edit in the image. For example, if you want to change a bowl of *fruits* to a bowl of *pears*, then:
|
|
|
```py |
|
source_prompt = "a bowl of fruits" |
|
target_prompt = "a bowl of pears" |
|
``` |
|
|
|
The partially inverted latents are generated from the [`~StableDiffusionDiffEditPipeline.invert`] function, and it is generally a good idea to include a `prompt` or *caption* describing the image to help guide the inverse latent sampling process. The caption can often be your `source_prompt`, but feel free to experiment with other text descriptions!
|
|
|
Let's load the pipeline, scheduler, and inverse scheduler, and enable some optimizations to reduce memory usage:
|
|
|
```py |
|
import torch |
|
from diffusers import DDIMScheduler, DDIMInverseScheduler, StableDiffusionDiffEditPipeline |
|
|
|
pipeline = StableDiffusionDiffEditPipeline.from_pretrained( |
|
"stabilityai/stable-diffusion-2-1", |
|
torch_dtype=torch.float16, |
|
safety_checker=None, |
|
use_safetensors=True, |
|
) |
|
pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config) |
|
pipeline.inverse_scheduler = DDIMInverseScheduler.from_config(pipeline.scheduler.config) |
|
pipeline.enable_model_cpu_offload() |
|
pipeline.enable_vae_slicing() |
|
``` |
|
|
|
Load the image to edit:
|
|
|
```py |
|
from diffusers.utils import load_image, make_image_grid |
|
|
|
img_url = "https://github.com/Xiang-cd/DiffEdit-stable-diffusion/raw/main/assets/origin.png" |
|
raw_image = load_image(img_url).resize((768, 768)) |
|
raw_image |
|
``` |
|
|
|
Use the [`~StableDiffusionDiffEditPipeline.generate_mask`] function to generate the image mask. You'll need to pass it the `source_prompt` and `target_prompt` to specify what to edit in the image:
|
|
|
```py |
|
from PIL import Image |
|
|
|
source_prompt = "a bowl of fruits" |
|
target_prompt = "a basket of pears" |
|
mask_image = pipeline.generate_mask( |
|
image=raw_image, |
|
source_prompt=source_prompt, |
|
target_prompt=target_prompt, |
|
) |
|
Image.fromarray((mask_image.squeeze()*255).astype("uint8"), "L").resize((768, 768)) |
|
``` |
|
|
|
Next, create the inverted latents and pass it a caption describing the image:
|
|
|
```py |
|
inv_latents = pipeline.invert(prompt=source_prompt, image=raw_image).latents |
|
``` |
|
|
|
Finally, pass the image mask and inverted latents to the pipeline. The `target_prompt` now becomes the `prompt`, and the `source_prompt` is used as the `negative_prompt`:
|
|
|
```py |
|
output_image = pipeline( |
|
prompt=target_prompt, |
|
mask_image=mask_image, |
|
image_latents=inv_latents, |
|
negative_prompt=source_prompt, |
|
).images[0] |
|
mask_image = Image.fromarray((mask_image.squeeze()*255).astype("uint8"), "L").resize((768, 768)) |
|
make_image_grid([raw_image, mask_image, output_image], rows=1, cols=3) |
|
``` |
|
|
|
<div class="flex gap-4"> |
|
<div> |
|
<img class="rounded-xl" src="https://github.com/Xiang-cd/DiffEdit-stable-diffusion/raw/main/assets/origin.png"/> |
|
<figcaption class="mt-2 text-center text-sm text-gray-500">original image</figcaption> |
|
</div> |
|
<div> |
|
<img class="rounded-xl" src="https://github.com/Xiang-cd/DiffEdit-stable-diffusion/blob/main/assets/target.png?raw=true"/> |
|
<figcaption class="mt-2 text-center text-sm text-gray-500">edited image</figcaption> |
|
</div> |
|
</div> |
|
|
|
## Generate source and target embeddings
|
|
|
The source and target embeddings can be automatically generated with the [Flan-T5](https://huggingface.co/docs/transformers/model_doc/flan-t5) model instead of creating them manually.
|
|
|
Load the Flan-T5 model and tokenizer from the 🤗 Transformers library:
|
|
|
```py |
|
import torch |
|
from transformers import AutoTokenizer, T5ForConditionalGeneration |
|
|
|
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-large") |
|
model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-large", device_map="auto", torch_dtype=torch.float16) |
|
``` |
|
|
|
Provide some initial text to prompt the model to generate the source and target prompts.
|
|
|
```py |
|
source_concept = "bowl" |
|
target_concept = "basket" |
|
|
|
source_text = f"Provide a caption for images containing a {source_concept}. " |
|
"The captions should be in English and should be no longer than 150 characters." |
|
|
|
target_text = f"Provide a caption for images containing a {target_concept}. " |
|
"The captions should be in English and should be no longer than 150 characters." |
|
``` |
|
|
|
Next, create a utility function to generate the prompts.
|
|
|
```py |
|
@torch.no_grad() |
|
def generate_prompts(input_prompt): |
|
input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids.to("cuda") |
|
|
|
outputs = model.generate( |
|
input_ids, temperature=0.8, num_return_sequences=16, do_sample=True, max_new_tokens=128, top_k=10 |
|
) |
|
return tokenizer.batch_decode(outputs, skip_special_tokens=True) |
|
|
|
source_prompts = generate_prompts(source_text) |
|
target_prompts = generate_prompts(target_text) |
|
print(source_prompts) |
|
print(target_prompts) |
|
``` |
|
|
|
<Tip> |
|
|
|
Check out the [generation strategies](https://huggingface.co/docs/transformers/main/en/generation_strategies) guide if you're interested in learning more about strategies for generating different quality text.
|
|
|
</Tip> |
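
For example, one alternative to the sampling strategy used in `generate_prompts` above is beam search, which typically trades diversity for more consistent captions. This is only an illustrative variation using standard `generate` arguments, not part of the original recipe:

```py
# Hypothetical variation: beam search instead of top-k sampling
outputs = model.generate(
    input_ids, num_beams=4, num_return_sequences=4, early_stopping=True, max_new_tokens=128
)
```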
|
|
|
Load the text encoder model used by the [`StableDiffusionDiffEditPipeline`] to encode the text. You'll use the text encoder to compute the text embeddings:
|
|
|
```py |
|
import torch |
|
from diffusers import StableDiffusionDiffEditPipeline |
|
|
|
pipeline = StableDiffusionDiffEditPipeline.from_pretrained( |
|
"stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16, use_safetensors=True |
|
) |
|
pipeline.enable_model_cpu_offload() |
|
pipeline.enable_vae_slicing() |
|
|
|
@torch.no_grad() |
|
def embed_prompts(sentences, tokenizer, text_encoder, device="cuda"): |
|
embeddings = [] |
|
for sent in sentences: |
|
text_inputs = tokenizer( |
|
sent, |
|
padding="max_length", |
|
max_length=tokenizer.model_max_length, |
|
truncation=True, |
|
return_tensors="pt", |
|
) |
|
text_input_ids = text_inputs.input_ids |
|
prompt_embeds = text_encoder(text_input_ids.to(device), attention_mask=None)[0] |
|
embeddings.append(prompt_embeds) |
|
return torch.concatenate(embeddings, dim=0).mean(dim=0).unsqueeze(0) |
|
|
|
source_embeds = embed_prompts(source_prompts, pipeline.tokenizer, pipeline.text_encoder) |
|
target_embeds = embed_prompts(target_prompts, pipeline.tokenizer, pipeline.text_encoder) |
|
``` |
|
|
|
Finally, pass the embeddings to the [`~StableDiffusionDiffEditPipeline.generate_mask`] and [`~StableDiffusionDiffEditPipeline.invert`] functions, and to the pipeline to generate the image:
|
|
|
```diff |
|
from diffusers import DDIMInverseScheduler, DDIMScheduler |
|
from diffusers.utils import load_image, make_image_grid |
|
from PIL import Image |
|
|
|
pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config) |
|
pipeline.inverse_scheduler = DDIMInverseScheduler.from_config(pipeline.scheduler.config) |
|
|
|
img_url = "https://github.com/Xiang-cd/DiffEdit-stable-diffusion/raw/main/assets/origin.png" |
|
raw_image = load_image(img_url).resize((768, 768)) |
|
|
|
mask_image = pipeline.generate_mask( |
|
image=raw_image, |
|
- source_prompt=source_prompt, |
|
- target_prompt=target_prompt, |
|
+ source_prompt_embeds=source_embeds, |
|
+ target_prompt_embeds=target_embeds, |
|
) |
|
|
|
inv_latents = pipeline.invert( |
|
- prompt=source_prompt, |
|
+ prompt_embeds=source_embeds, |
|
image=raw_image, |
|
).latents |
|
|
|
output_image = pipeline( |
|
mask_image=mask_image, |
|
image_latents=inv_latents, |
|
- prompt=target_prompt, |
|
- negative_prompt=source_prompt, |
|
+ prompt_embeds=target_embeds, |
|
+ negative_prompt_embeds=source_embeds, |
|
).images[0] |
|
mask_image = Image.fromarray((mask_image.squeeze()*255).astype("uint8"), "L") |
|
make_image_grid([raw_image, mask_image, output_image], rows=1, cols=3) |
|
``` |
|
|
|
## Generate a caption for inversion
|
|
|
While you can use the `source_prompt` as a caption to help generate the partially inverted latents, you can also use the [BLIP](https://huggingface.co/docs/transformers/model_doc/blip) model to automatically generate a caption.
|
|
|
Load the BLIP model and processor from the 🤗 Transformers library:
|
|
|
```py |
|
import torch |
|
from transformers import BlipForConditionalGeneration, BlipProcessor |
|
|
|
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base") |
|
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base", torch_dtype=torch.float16, low_cpu_mem_usage=True) |
|
``` |
|
|
|
Create a utility function to generate a caption from the input image:
|
|
|
```py |
|
@torch.no_grad() |
|
def generate_caption(images, caption_generator, caption_processor): |
|
text = "a photograph of" |
|
|
|
inputs = caption_processor(images, text, return_tensors="pt").to(device="cuda", dtype=caption_generator.dtype) |
|
caption_generator.to("cuda") |
|
outputs = caption_generator.generate(**inputs, max_new_tokens=128) |
|
|
|
    # offload the caption generator back to the CPU to free GPU memory
|
caption_generator.to("cpu") |
|
|
|
caption = caption_processor.batch_decode(outputs, skip_special_tokens=True)[0] |
|
return caption |
|
``` |
|
|
|
Load an input image and generate a caption for it using the `generate_caption` function:
|
|
|
```py |
|
from diffusers.utils import load_image |
|
|
|
img_url = "https://github.com/Xiang-cd/DiffEdit-stable-diffusion/raw/main/assets/origin.png" |
|
raw_image = load_image(img_url).resize((768, 768)) |
|
caption = generate_caption(raw_image, model, processor) |
|
``` |
|
|
|
<div class="flex justify-center"> |
|
<figure> |
|
<img class="rounded-xl" src="https://github.com/Xiang-cd/DiffEdit-stable-diffusion/raw/main/assets/origin.png"/> |
|
<figcaption class="text-center">generated caption: "a photograph of a bowl of fruit on a table"</figcaption> |
|
</figure> |
|
</div> |
|
|
|
Now you can drop the caption into the [`~StableDiffusionDiffEditPipeline.invert`] function to generate the partially inverted latents!
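
For example, a minimal sketch reusing the `pipeline` and `raw_image` loaded earlier in this guide:

```py
# Guide the inverse latent sampling with the BLIP-generated caption
inv_latents = pipeline.invert(prompt=caption, image=raw_image).latents
```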
|
|