seawolf2357's picture
Update app.py
8a86340 verified
raw
history blame
8.84 kB
import gradio as gr
import spaces
import torch
from diffusers import FluxKontextPipeline
from diffusers.utils import load_image
from PIL import Image
import os
# ----------------------------------------------
# Style โ†’ LoRA file mapping
# ----------------------------------------------
STYLE_LORA_MAP = {
"3D_Chibi": "3D_Chibi_lora_weights.safetensors",
"American_Cartoon": "American_Cartoon_lora_weights.safetensors",
"Chinese_Ink": "Chinese_Ink_lora_weights.safetensors",
"Clay_Toy": "Clay_Toy_lora_weights.safetensors",
"Fabric": "Fabric_lora_weights.safetensors",
"Ghibli": "Ghibli_lora_weights.safetensors",
"Irasutoya": "Irasutoya_lora_weights.safetensors",
"Jojo": "Jojo_lora_weights.safetensors",
"Oil_Painting": "Oil_Painting_lora_weights.safetensors",
"Pixel": "Pixel_lora_weights.safetensors",
"Snoopy": "Snoopy_lora_weights.safetensors",
"Poly": "Poly_lora_weights.safetensors",
"LEGO": "LEGO_lora_weights.safetensors",
"Origami": "Origami_lora_weights.safetensors",
"Pop_Art": "Pop_Art_lora_weights.safetensors",
"Van_Gogh": "Van_Gogh_lora_weights.safetensors",
"Paper_Cutting": "Paper_Cutting_lora_weights.safetensors",
"Line": "Line_lora_weights.safetensors",
"Vector": "Vector_lora_weights.safetensors",
"Picasso": "Picasso_lora_weights.safetensors",
"Macaron": "Macaron_lora_weights.safetensors",
"Rick_Morty": "Rick_Morty_lora_weights.safetensors"
}
# ----------------------------------------------
# Style descriptions (ํˆดํŒ์šฉ)
# ----------------------------------------------
STYLE_DESCRIPTIONS = {
"3D_Chibi": "๊ท€์—ฌ์šด SD ์บ๋ฆญํ„ฐ์˜ 3D ๋А๋‚Œ",
"American_Cartoon": "๊ณ ์ „์ ์ธ ๋ฏธ๊ตญ ์นดํˆฐ ์Šคํƒ€์ผ",
"Chinese_Ink": "์ˆ˜๋ฌตํ™”์˜ ๋ฒˆ์ง๊ณผ ๋†๋‹ด ํ‘œํ˜„",
"Clay_Toy": "์ฐฐํ™ยทํ”Œ๋ผ์Šคํ‹ด ์žฅ๋‚œ๊ฐ่ณช",
"Fabric": "์„ฌ์œ ยทํŒจ๋ธŒ๋ฆญ ์งˆ๊ฐ",
"Ghibli": "์ง€๋ธŒ๋ฆฌํ’ ๋”ฐ๋œปํ•œ ์ƒ‰๊ฐ & ์—ฐํ•„์„ ",
"Irasutoya": "์ผ๋Ÿฌ์Šคํ† ์•ผ ๋ฏธ๋‹ˆ๋ฉ€ ํ‰๋ฉด ๊ทธ๋ฆผ",
"Jojo": "์ฃ ์ฃ ์˜ ๊ธฐ๋ฌ˜ํ•œ ๋ชจํ—˜ ๋ง๊ฐ€ ํ„ฐ์น˜",
"Oil_Painting": "์œ ํ™” ๋ถ“ํ„ฐ์น˜์™€ ์งˆ๊ฐ",
"Pixel": "16/32โ€‘bit ๋ ˆํŠธ๋กœ ํ”ฝ์…€์•„ํŠธ",
"Snoopy": "ํ”ผ๋„ˆ์ธ  ์ŠคํŠธ๋ฆฝ(์Šค๋ˆ„ํ”ผ) ์Šคํƒ€์ผ",
"Poly": "๋กœ์šฐํด๋ฆฌ 3D ๊ธฐํ•˜ํ•™์  ์Šคํƒ€์ผ",
"LEGO": "๋ ˆ๊ณ  ๋ธ”๋ก ์กฐ๋ฆฝ ์Šคํƒ€์ผ",
"Origami": "์ข…์ด์ ‘๊ธฐ ์งˆ๊ฐยท๊ฐ๋„",
"Pop_Art": "ํŒ์•„ํŠธ์˜ ์„ ๋ช…ํ•œ ์ƒ‰๊ฐ๊ณผ ๋„ํŠธ",
"Van_Gogh": "๋ฐ˜ ๊ณ ํ์˜ ๊ตต์€ ์ž„ํŒŒ์Šคํ† ",
"Paper_Cutting": "์ข…์ด ์˜ค๋ฆฌ๊ธฐ ์‹ค๋ฃจ์—ฃ",
"Line": "๊น”๋”ํ•œ ๋ผ์ธ ๋“œ๋กœ์ž‰",
"Vector": "๋ฒกํ„ฐ ๊ทธ๋ž˜ํ”ฝยทํ”Œ๋žซ ๋””์ž์ธ",
"Picasso": "ํ”ผ์นด์†Œ์‹ ์ž…์ฒด์ฃผ์˜ ํ๋น„์ฆ˜",
"Macaron": "ํŒŒ์Šคํ…” ๋งˆ์นด๋กฑํ†ค ๋ถ€๋“œ๋Ÿฌ์›€",
"Rick_Morty": "๋ฆญ ์•ค ๋ชจํ‹ฐ ์• ๋‹ˆ๋ฉ”์ด์…˜ ์Šคํƒ€์ผ"
}
# ์ „์—ญ ํŒŒ์ดํ”„๋ผ์ธ
pipe = None
def get_dtype():
"""์ตœ์  dtype(bf16 ์ง€์› ์‹œ bf16, ์•„๋‹ˆ๋ฉด fp16) ์„ ํƒ"""
if torch.cuda.is_available():
major, _ = torch.cuda.get_device_capability()
if major >= 8: # Ada/Hopper GPU๋Š” bf16 ๋ณธ๊ฒฉ ์ง€์›
return torch.bfloat16
return torch.float16
def load_pipeline():
"""FluxKontext ํŒŒ์ดํ”„๋ผ์ธ์„ (์ง€์—ฐ)๋กœ๋“œ"""
global pipe
if pipe is None:
gr.Info("โฌ‡๏ธ FLUX.1โ€‘Kontext ๋ชจ๋ธ์„ ๋‹ค์šด๋กœ๋“œํ•ฉ๋‹ˆ๋‹คโ€ฆ")
pipe = FluxKontextPipeline.from_pretrained(
"black-forest-labs/FLUX.1-Kontext-dev",
torch_dtype=get_dtype(),
resume_download=True,
)
device = "cuda" if torch.cuda.is_available() else "cpu"
pipe.to(device)
# VRAM ์ ˆ์•ฝ ๋ชจ๋“œ
if torch.cuda.is_available():
pipe.enable_sequential_cpu_offload()
pipe.vae.enable_tiling()
return pipe
@spaces.GPU(duration=600) # ์ดˆ๊ธฐ ๋‹ค์šด๋กœ๋“œ ์‹œ๊ฐ„ ํ™•๋ณด
def style_transfer(
input_image,
style_name,
prompt_suffix,
num_inference_steps,
guidance_scale,
seed,
):
"""์„ ํƒํ•œ ์Šคํƒ€์ผ๋กœ ์ด๋ฏธ์ง€ ๋ณ€ํ™˜"""
if input_image is None:
gr.Warning("๐Ÿ–ผ๏ธ ๋จผ์ € ์ด๋ฏธ์ง€๋ฅผ ์—…๋กœ๋“œํ•˜์„ธ์š”!")
return None
try:
pipe = load_pipeline()
# ์‹œ๋“œ ๊ณ ์ • (0์ด๋ฉด ๋‚œ์ˆ˜)
generator = None
if seed and int(seed) != 0:
generator = torch.Generator(device=pipe.device).manual_seed(int(seed))
# ์ด๋ฏธ์ง€ ๋กœ๋“œ & ๋ฆฌ์‚ฌ์ด์ฆˆ
img = load_image(input_image) if isinstance(input_image, str) else input_image
img = img.convert("RGB").resize((1024, 1024), Image.Resampling.LANCZOS)
# LoRA ๋กœ๋”ฉ
lora_file = STYLE_LORA_MAP[style_name]
adapter_name = "style"
pipe.load_lora_weights(
"Owen777/Kontext-Style-Loras",
weight_name=lora_file,
adapter_name=adapter_name,
)
pipe.set_adapters([adapter_name], adapter_weights=[1.0])
# ํ”„๋กฌํ”„ํŠธ ๊ตฌ์„ฑ
prompt = f"Turn this image into the {style_name.replace('_', ' ')} style."
if prompt_suffix and prompt_suffix.strip():
prompt += f" {prompt_suffix.strip()}"
gr.Info("๐ŸŽจ ์ด๋ฏธ์ง€๋ฅผ ์ƒ์„ฑ ์ค‘โ€ฆ")
result = pipe(
image=img,
prompt=prompt,
guidance_scale=float(guidance_scale),
num_inference_steps=int(num_inference_steps),
generator=generator,
height=1024,
width=1024,
)
# LoRA ์–ธ๋กœ๋“œ ๋ฐ ์บ์‹œ ์ •๋ฆฌ
pipe.unload_lora_weights()
torch.cuda.empty_cache()
return result.images[0]
except Exception as e:
gr.Error(f"๐Ÿšจ ์˜ค๋ฅ˜: {e}")
torch.cuda.empty_cache()
return None
def update_description(style):
return STYLE_DESCRIPTIONS.get(style, "")
with gr.Blocks(title="FLUX.1 Kontext Style Transfer", theme=gr.themes.Soft()) as demo:
gr.Markdown(
"""
# ๐ŸŽจ FLUX.1 Kontext Style Transfer
FLUX.1โ€‘Kontextโ€‘dev ๋ชจ๋ธ๊ณผ 22๊ฐœ์˜ ๊ณ ํ’ˆ์งˆ LoRA๋กœ ์ด๋ฏธ์ง€๋ฅผ ๋‹ค์–‘ํ•œ ์˜ˆ์ˆ  ์Šคํƒ€์ผ๋กœ ๋ณ€ํ™˜ํ•˜์„ธ์š”.
"""
)
with gr.Row():
with gr.Column(scale=1):
input_image = gr.Image(label="Upload Image", type="pil", height=400)
style_dropdown = gr.Dropdown(
choices=list(STYLE_LORA_MAP.keys()),
value="Ghibli",
label="Select Style",
info="Choose from 22 different artistic styles",
)
style_info = gr.Textbox(
label="Style Description",
value=STYLE_DESCRIPTIONS["Ghibli"],
interactive=False,
lines=2,
)
prompt_suffix = gr.Textbox(
label="Additional Instructions (Optional)",
placeholder="์˜ˆ: 'make it more colorful', 'add dramatic lighting' โ€ฆ",
lines=2,
)
with gr.Accordion("Advanced Settings", open=False):
num_steps = gr.Slider(
10,
50,
value=24,
step=1,
label="Inference Steps",
info="๋” ๋†’์„์ˆ˜๋ก ํ’ˆ์งˆโ†‘ ์†๋„โ†“",
)
guidance = gr.Slider(
1.0,
7.5,
value=2.5,
step=0.1,
label="Guidance Scale",
info="ํ”„๋กฌํ”„ํŠธ ์ค€์ˆ˜ ์ •๋„",
)
seed = gr.Number(label="Seed (0 = Random)", value=0)
generate_btn = gr.Button(
"๐ŸŽจ Transform Image", variant="primary", size="lg"
)
with gr.Column(scale=1):
output_image = gr.Image(label="Styled Result", type="pil", height=400)
gr.Markdown(
"""### ๐Ÿ’ก Tips:
- ๋ชจ๋“  ์ด๋ฏธ์ง€๋Š” 1024ร—1024๋กœ ๋ฆฌ์‚ฌ์ด์ฆˆ๋ฉ๋‹ˆ๋‹ค.
- ์ฒซ ์‹คํ–‰ ์‹œ 7โ€ฏGB ๋ชจ๋ธ ๋‹ค์šด๋กœ๋“œ๊ฐ€ ํ•„์š”ํ•ฉ๋‹ˆ๋‹ค.
- ์Šคํƒ€์ผ ๋ณ€ํ™˜์€ ์•ฝ 30โ€‘60โ€ฏ์ดˆ ์†Œ์š”๋ฉ๋‹ˆ๋‹ค.
- ๋‹ค๋ฅธ ์Šคํƒ€์ผ๋„ ์‹œํ—˜ํ•ด ๋ณด์„ธ์š”!"""
)
# ์ด๋ฒคํŠธ ๋ฐ”์ธ๋”ฉ
style_dropdown.change(update_description, [style_dropdown], [style_info])
generate_btn.click(
style_transfer,
inputs=[
input_image,
style_dropdown,
prompt_suffix,
num_steps,
guidance,
seed,
],
outputs=[output_image],
)
gr.Markdown(
"""
---
Created with โค๏ธ by [Blackโ€‘Forest Labs](https://huggingface.co/black-forest-labs) &
[Owen777/Kontextโ€‘Styleโ€‘Loras](https://huggingface.co/Owen777/Kontext-Style-Loras)
"""
)
if __name__ == "__main__":
demo.launch()