# Hugging Face Space status: Running on Zero (ZeroGPU hardware).
import logging
import os

import gradio as gr
import spaces
import torch
from diffusers import FluxKontextPipeline
from diffusers.utils import load_image
from PIL import Image
# ----------------------------------------------
# Style → LoRA file mapping
# ----------------------------------------------
# Each style's weight file follows the pattern "<Style>_lora_weights.safetensors";
# the key order below is also the order of the style dropdown in the UI.
STYLE_LORA_MAP = {
    style: f"{style}_lora_weights.safetensors"
    for style in (
        "3D_Chibi",
        "American_Cartoon",
        "Chinese_Ink",
        "Clay_Toy",
        "Fabric",
        "Ghibli",
        "Irasutoya",
        "Jojo",
        "Oil_Painting",
        "Pixel",
        "Snoopy",
        "Poly",
        "LEGO",
        "Origami",
        "Pop_Art",
        "Van_Gogh",
        "Paper_Cutting",
        "Line",
        "Vector",
        "Picasso",
        "Macaron",
        "Rick_Morty",
    )
}
# ----------------------------------------------
# Style descriptions (shown in the "Style Description" box)
# ----------------------------------------------
# One-line Korean blurbs, keyed by the same style names as STYLE_LORA_MAP.
STYLE_DESCRIPTIONS = {
    "3D_Chibi": "귀여운 SD 캐릭터의 3D 느낌",
    "American_Cartoon": "고전적인 미국 카툰 스타일",
    "Chinese_Ink": "수묵화의 번짐과 농담 표현",
    "Clay_Toy": "찰흙·플라스틴 장난감 질감",
    "Fabric": "섬유·패브릭 질감",
    "Ghibli": "지브리풍 따뜻한 색감 & 연필선",
    "Irasutoya": "일러스토야 미니멀 평면 그림",
    "Jojo": "죠죠의 기묘한 모험 망가 터치",
    "Oil_Painting": "유화 붓터치의 질감",
    "Pixel": "16/32-bit 레트로 픽셀 아트",
    "Snoopy": "피너츠 스트립(스누피) 스타일",
    "Poly": "로우폴리 3D 기하학적 스타일",
    "LEGO": "레고 블록 조립 스타일",
    "Origami": "종이접기 질감·각도",
    "Pop_Art": "팝아트의 선명한 색감과 도트",
    "Van_Gogh": "반 고흐의 굵은 임파스토",
    "Paper_Cutting": "종이 오리기 실루엣",
    "Line": "깔끔한 라인 드로잉",
    "Vector": "벡터 그래픽·플랫 디자인",
    "Picasso": "피카소의 입체주의 큐비즘",
    "Macaron": "파스텔 마카롱톤 부드러움",
    "Rick_Morty": "릭 앤 모티 애니메이션 스타일",
}
# Process-wide pipeline cache: empty at import time, filled by load_pipeline()
# on the first request so the app can start without loading the model.
pipe: "FluxKontextPipeline | None" = None
def get_dtype():
    """Choose the dtype the pipeline is loaded in.

    Returns ``torch.bfloat16`` when a CUDA device with compute capability 8.0
    or newer is present, and ``torch.float16`` otherwise (including CPU-only).
    """
    if not torch.cuda.is_available():
        return torch.float16
    # Capability major >= 8 means Ampere or newer (Ampere, Ada, Hopper),
    # the generations with native bf16 support.
    major_capability = torch.cuda.get_device_capability()[0]
    return torch.bfloat16 if major_capability >= 8 else torch.float16
def load_pipeline():
    """Lazily load the FLUX.1-Kontext pipeline and cache it in the global ``pipe``.

    The first call fetches ``black-forest-labs/FLUX.1-Kontext-dev`` and applies
    the memory-saving configuration; later calls return the cached instance.

    Returns:
        The shared ``FluxKontextPipeline``.
    """
    global pipe
    if pipe is not None:
        return pipe

    gr.Info("⬇️ FLUX.1-Kontext 모델을 다운로드합니다…")
    # `resume_download` was dropped: it is deprecated and resuming is the default.
    pipeline = FluxKontextPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-Kontext-dev",
        torch_dtype=get_dtype(),
    )

    if torch.cuda.is_available():
        # VRAM-saving mode. Sequential CPU offload manages device placement
        # itself, so the pipeline is deliberately NOT moved with .to("cuda")
        # first - that would load the entire model into VRAM just to evict it.
        pipeline.enable_sequential_cpu_offload()
        pipeline.vae.enable_tiling()
    else:
        pipeline.to("cpu")

    # Publish only a fully configured pipeline: if anything above raises, the
    # next call retries instead of reusing a half-initialised object.
    pipe = pipeline
    return pipe
# Reserve extra GPU time because the first call also downloads the model.
# NOTE(review): 180 s is a budget chosen for that first call - tune to your quota.
@spaces.GPU(duration=180)
def style_transfer(
    input_image,
    style_name,
    prompt_suffix,
    num_inference_steps,
    guidance_scale,
    seed,
):
    """Restyle an uploaded image with the LoRA registered for ``style_name``.

    Args:
        input_image: PIL image from the UI, or a path/URL string (loaded via
            ``load_image``). ``None`` means nothing was uploaded.
        style_name: key of ``STYLE_LORA_MAP``.
        prompt_suffix: optional extra instructions appended to the prompt.
        num_inference_steps: denoising steps (cast to ``int``).
        guidance_scale: guidance strength (cast to ``float``).
        seed: fixed seed; 0 or empty means a random seed.

    Returns:
        The generated 1024x1024 image, or ``None`` if no image was uploaded.

    Raises:
        gr.Error: when loading the model/LoRA or running inference fails; Gradio
            shows the message in the UI.
    """
    if input_image is None:
        gr.Warning("🖼️ 먼저 이미지를 업로드하세요!")
        return None

    try:
        pipe = load_pipeline()

        # Fixed seed when non-zero; 0/empty leaves generation random. A CPU
        # generator is used because with sequential CPU offload `pipe.device`
        # may not report the execution device; diffusers accepts CPU generators
        # for any target device.
        generator = None
        if seed and int(seed) != 0:
            generator = torch.Generator(device="cpu").manual_seed(int(seed))

        # Load & resize the image (aspect ratio is not preserved, by design).
        img = load_image(input_image) if isinstance(input_image, str) else input_image
        img = img.convert("RGB").resize((1024, 1024), Image.Resampling.LANCZOS)

        # Build the prompt.
        prompt = f"Turn this image into the {style_name.replace('_', ' ')} style."
        if prompt_suffix and prompt_suffix.strip():
            prompt += f" {prompt_suffix.strip()}"

        lora_file = STYLE_LORA_MAP[style_name]
        adapter_name = "style"
        try:
            pipe.load_lora_weights(
                "Owen777/Kontext-Style-Loras",
                weight_name=lora_file,
                adapter_name=adapter_name,
            )
            pipe.set_adapters([adapter_name], adapter_weights=[1.0])

            gr.Info("🎨 이미지를 생성 중…")
            result = pipe(
                image=img,
                prompt=prompt,
                guidance_scale=float(guidance_scale),
                num_inference_steps=int(num_inference_steps),
                generator=generator,
                height=1024,
                width=1024,
            )
        finally:
            # Always drop the adapter; otherwise a failed run leaves "style"
            # registered and every later load under that name fails too.
            pipe.unload_lora_weights()

        return result.images[0]
    except Exception as e:
        logging.getLogger(__name__).exception("Style transfer failed")
        # gr.Error only reaches the UI when raised, not when merely constructed.
        raise gr.Error(f"🚨 오류: {e}") from e
    finally:
        torch.cuda.empty_cache()
def update_description(style):
    """Return the description text for ``style``, or an empty string if unknown."""
    if style in STYLE_DESCRIPTIONS:
        return STYLE_DESCRIPTIONS[style]
    return ""
with gr.Blocks(title="FLUX.1 Kontext Style Transfer", theme=gr.themes.Soft()) as demo: | |
gr.Markdown( | |
""" | |
# ๐จ FLUX.1 Kontext Style Transfer | |
FLUX.1โKontextโdev ๋ชจ๋ธ๊ณผ 22๊ฐ์ ๊ณ ํ์ง LoRA๋ก ์ด๋ฏธ์ง๋ฅผ ๋ค์ํ ์์ ์คํ์ผ๋ก ๋ณํํ์ธ์. | |
""" | |
) | |
with gr.Row(): | |
with gr.Column(scale=1): | |
input_image = gr.Image(label="Upload Image", type="pil", height=400) | |
style_dropdown = gr.Dropdown( | |
choices=list(STYLE_LORA_MAP.keys()), | |
value="Ghibli", | |
label="Select Style", | |
info="Choose from 22 different artistic styles", | |
) | |
style_info = gr.Textbox( | |
label="Style Description", | |
value=STYLE_DESCRIPTIONS["Ghibli"], | |
interactive=False, | |
lines=2, | |
) | |
prompt_suffix = gr.Textbox( | |
label="Additional Instructions (Optional)", | |
placeholder="์: 'make it more colorful', 'add dramatic lighting' โฆ", | |
lines=2, | |
) | |
with gr.Accordion("Advanced Settings", open=False): | |
num_steps = gr.Slider( | |
10, | |
50, | |
value=24, | |
step=1, | |
label="Inference Steps", | |
info="๋ ๋์์๋ก ํ์งโ ์๋โ", | |
) | |
guidance = gr.Slider( | |
1.0, | |
7.5, | |
value=2.5, | |
step=0.1, | |
label="Guidance Scale", | |
info="ํ๋กฌํํธ ์ค์ ์ ๋", | |
) | |
seed = gr.Number(label="Seed (0 = Random)", value=0) | |
generate_btn = gr.Button( | |
"๐จ Transform Image", variant="primary", size="lg" | |
) | |
with gr.Column(scale=1): | |
output_image = gr.Image(label="Styled Result", type="pil", height=400) | |
gr.Markdown( | |
"""### ๐ก Tips: | |
- ๋ชจ๋ ์ด๋ฏธ์ง๋ 1024ร1024๋ก ๋ฆฌ์ฌ์ด์ฆ๋ฉ๋๋ค. | |
- ์ฒซ ์คํ ์ 7โฏGB ๋ชจ๋ธ ๋ค์ด๋ก๋๊ฐ ํ์ํฉ๋๋ค. | |
- ์คํ์ผ ๋ณํ์ ์ฝ 30โ60โฏ์ด ์์๋ฉ๋๋ค. | |
- ๋ค๋ฅธ ์คํ์ผ๋ ์ํํด ๋ณด์ธ์!""" | |
) | |
# ์ด๋ฒคํธ ๋ฐ์ธ๋ฉ | |
style_dropdown.change(update_description, [style_dropdown], [style_info]) | |
generate_btn.click( | |
style_transfer, | |
inputs=[ | |
input_image, | |
style_dropdown, | |
prompt_suffix, | |
num_steps, | |
guidance, | |
seed, | |
], | |
outputs=[output_image], | |
) | |
gr.Markdown( | |
""" | |
--- | |
Created with โค๏ธ by [BlackโForest Labs](https://huggingface.co/black-forest-labs) & | |
[Owen777/KontextโStyleโLoras](https://huggingface.co/Owen777/Kontext-Style-Loras) | |
""" | |
) | |
if __name__ == "__main__": | |
demo.launch() | |