""" FLUX.1 Kontext Style Transfer ============================== Updated: 2025‑07‑12 --------------------------------- 이 스크립트는 Hugging Face **FLUX.1‑Kontext‑dev** 모델과 22 종의 스타일 LoRA 가중치를 이용해 이미지를 다양한 예술 스타일로 변환하는 Gradio 데모입니다. 주요 개선 사항 -------------- 1. **모델 캐싱** – `snapshot_download()`로 실행 시작 시 한 번만 모델과 LoRA를 캐싱해 이후 GPU 잡에서도 재다운로드가 없도록 함. 2. **GPU VRAM 자동 판별** – GPU VRAM이 24 GB 미만이면 `torch.float16` / `enable_sequential_cpu_offload()`를 자동 적용. 3. **단일 로딩 메시지** – Gradio `gr.Info()` 메시지가 최초 1회만 표시되도록 수정. 4. **버그 픽스** – seed 처리, LoRA 언로드, 이미지 리사이즈 로직 등 세부 오류 수정. ------------------------------------------------------------ """ import os import gradio as gr import spaces import torch from huggingface_hub import snapshot_download from diffusers import FluxKontextPipeline from diffusers.utils import load_image from PIL import Image # ------------------------------------------------------------------ # 환경 설정 & 모델 / LoRA 사전 다운로드 # ------------------------------------------------------------------ # 큰 파일을 빠르게 받도록 가속 플래그 활성화 os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1") MODEL_ID = "black-forest-labs/FLUX.1-Kontext-dev" LORA_REPO = "Owen777/Kontext-Style-Loras" CACHE_DIR = os.getenv("HF_HOME", os.path.expanduser("~/.cache/huggingface")) # --- 최초 실행 시에만 다운로드(이미 캐시에 있으면 건너뜀) --- MODEL_DIR = snapshot_download( repo_id=MODEL_ID, cache_dir=CACHE_DIR, resume_download=True, token=True # HF 토큰(필요 시 환경변수 HF_TOKEN 지정) ) LORA_DIR = snapshot_download( repo_id=LORA_REPO, cache_dir=CACHE_DIR, resume_download=True, token=True ) # ------------------------------------------------------------------ # 스타일 → LoRA 파일 매핑 & 설명 # ------------------------------------------------------------------ STYLE_LORA_MAP = { "3D_Chibi": "3D_Chibi_lora_weights.safetensors", "American_Cartoon": "American_Cartoon_lora_weights.safetensors", "Chinese_Ink": "Chinese_Ink_lora_weights.safetensors", "Clay_Toy": "Clay_Toy_lora_weights.safetensors", "Fabric": "Fabric_lora_weights.safetensors", "Ghibli": "Ghibli_lora_weights.safetensors", "Irasutoya": "Irasutoya_lora_weights.safetensors", "Jojo": "Jojo_lora_weights.safetensors", "Oil_Painting": "Oil_Painting_lora_weights.safetensors", "Pixel": "Pixel_lora_weights.safetensors", "Snoopy": "Snoopy_lora_weights.safetensors", "Poly": "Poly_lora_weights.safetensors", "LEGO": "LEGO_lora_weights.safetensors", "Origami": "Origami_lora_weights.safetensors", "Pop_Art": "Pop_Art_lora_weights.safetensors", "Van_Gogh": "Van_Gogh_lora_weights.safetensors", "Paper_Cutting": "Paper_Cutting_lora_weights.safetensors", "Line": "Line_lora_weights.safetensors", "Vector": "Vector_lora_weights.safetensors", "Picasso": "Picasso_lora_weights.safetensors", "Macaron": "Macaron_lora_weights.safetensors", "Rick_Morty": "Rick_Morty_lora_weights.safetensors", } STYLE_DESCRIPTIONS = { "3D_Chibi": "Cute, miniature 3D character style with big heads", "American_Cartoon": "Classic American animation style", "Chinese_Ink": "Traditional Chinese ink painting aesthetic", "Clay_Toy": "Playful clay/plasticine toy appearance", "Fabric": "Soft, textile-like rendering", "Ghibli": "Studio Ghibli's distinctive anime style", "Irasutoya": "Simple, flat Japanese illustration style", "Jojo": "JoJo's Bizarre Adventure manga style", "Oil_Painting": "Classic oil painting texture and strokes", "Pixel": "Retro pixel art style", "Snoopy": "Peanuts comic strip style", "Poly": "Low-poly 3D geometric style", "LEGO": "LEGO brick construction style", "Origami": "Paper folding art style", "Pop_Art": "Bold, colorful pop art style", "Van_Gogh": "Van Gogh's expressive brushstroke style", "Paper_Cutting": "Paper cut-out art style", "Line": "Clean line art/sketch style", "Vector": "Clean vector graphics style", "Picasso": "Cubist art style inspired by Picasso", "Macaron": "Soft, pastel macaron-like style", "Rick_Morty": "Rick and Morty cartoon style", } # ------------------------------------------------------------------ # 파이프라인 로더 (단일 인스턴스) # ------------------------------------------------------------------ _pipeline = None # 내부 글로벌 캐시 def load_pipeline(): """Load (or return cached) FluxKontextPipeline.""" global _pipeline if _pipeline is not None: return _pipeline # VRAM이 24 GB 미만이면 FP16 사용 + CPU 오프로딩 dtype = torch.bfloat16 vram_gb = torch.cuda.get_device_properties(0).total_memory / 1024**3 if vram_gb < 24: dtype = torch.float16 gr.Info("FLUX.1‑Kontext 파이프라인 로딩 중… (최초 1회)") pipe = FluxKontextPipeline.from_pretrained( MODEL_DIR, torch_dtype=dtype, local_files_only=True, ) pipe.to("cuda") if vram_gb < 24: pipe.enable_sequential_cpu_offload() else: pipe.enable_model_cpu_offload() _pipeline = pipe return _pipeline # ------------------------------------------------------------------ # 스타일 변환 함수 (Spaces GPU 잡) # ------------------------------------------------------------------ @spaces.GPU(duration=600) def style_transfer(input_image, style_name, prompt_suffix, num_inference_steps, guidance_scale, seed): """Apply selected style to the uploaded image.""" if input_image is None: gr.Warning("Please upload an image first!") return None try: pipe = load_pipeline() # --- Torch Generator 설정 --- if seed > 0: generator = torch.Generator(device="cuda").manual_seed(int(seed)) else: generator = None # random # --- 입력 이미지 전처리 --- img = input_image if isinstance(input_image, Image.Image) else load_image(input_image) img = img.convert("RGB").resize((1024, 1024), Image.Resampling.LANCZOS) # --- LoRA 로드 --- lora_file = STYLE_LORA_MAP[style_name] adapter_name = "style" pipe.load_lora_weights(LORA_DIR, weight_name=lora_file, adapter_name=adapter_name) pipe.set_adapters([adapter_name], [1.0]) # --- 프롬프트 빌드 --- human_readable_style = style_name.replace("_", " ") prompt = f"Turn this image into the {human_readable_style} style." if prompt_suffix and prompt_suffix.strip(): prompt += f" {prompt_suffix.strip()}" gr.Info("Generating styled image… (24‑60 s)") result = pipe( image=img, prompt=prompt, guidance_scale=float(guidance_scale), num_inference_steps=int(num_inference_steps), generator=generator, height=1024, width=1024, ) # --- LoRA 언로드 & GPU 메모리 해제 --- pipe.unload_lora_weights(adapter_name=adapter_name) torch.cuda.empty_cache() return result.images[0] except Exception as e: torch.cuda.empty_cache() gr.Error(f"Error during style transfer: {e}") return None # ------------------------------------------------------------------ # Gradio UI 정의 # ------------------------------------------------------------------ with gr.Blocks(title="FLUX.1 Kontext Style Transfer", theme=gr.themes.Soft()) as demo: gr.Markdown(""" # 🎨 FLUX.1 Kontext Style Transfer 업로드한 이미지를 22 종의 예술 스타일로 변환하세요! (모델 / LoRA는 최초 실행 시에만 다운로드되며, 이후 실행은 빠릅니다.) """) with gr.Row(): with gr.Column(scale=1): input_image = gr.Image(label="Upload Image", type="pil", height=400) style_dropdown = gr.Dropdown( choices=list(STYLE_LORA_MAP.keys()), value="Ghibli", label="Select Style", ) style_info = gr.Textbox(label="Style Description", value=STYLE_DESCRIPTIONS["Ghibli"], interactive=False, lines=2) prompt_suffix = gr.Textbox(label="Additional Instructions (Optional)", placeholder="e.g. add dramatic lighting", lines=2) with gr.Accordion("Advanced Settings", open=False): num_steps = gr.Slider(minimum=10, maximum=50, value=24, step=1, label="Inference Steps") guidance = gr.Slider(minimum=1.0, maximum=7.5, value=2.5, step=0.1, label="Guidance Scale") seed = gr.Number(label="Seed (0 = random)", value=42) generate_btn = gr.Button("🎨 Transform Image", variant="primary", size="lg") with gr.Column(scale=1): output_image = gr.Image(label="Styled Result", type="pil", height=400) gr.Markdown(""" ### 💡 Tips * 이미지 크기는 1024×1024로 리사이즈됩니다. * 최초 1회 모델 + LoRA 다운로드 후에는 **캐시**를 사용하므로 10‑20 s 내 완료됩니다. * "Additional Instructions"에 색감·조명·효과 등을 영어로 간단히 적으면 결과를 세밀하게 제어할 수 있습니다. """) # --- 스타일 설명 자동 업데이트 --- def _update_desc(style): return STYLE_DESCRIPTIONS.get(style, "") style_dropdown.change(fn=_update_desc, inputs=[style_dropdown], outputs=[style_info]) # --- 예제 --- gr.Examples( examples=[ ["https://huggingface.co/datasets/black-forest-labs/kontext-bench/resolve/main/test/images/0003.jpg", "Ghibli", ""], ["https://huggingface.co/datasets/black-forest-labs/kontext-bench/resolve/main/test/images/0003.jpg", "3D_Chibi", "make it extra cute"], ["https://huggingface.co/datasets/black-forest-labs/kontext-bench/resolve/main/test/images/0003.jpg", "Van_Gogh", "with swirling sky"], ], inputs=[input_image, style_dropdown, prompt_suffix], outputs=output_image, fn=lambda img, style, prompt: style_transfer(img, style, prompt, 24, 2.5, 42), cache_examples=False, ) # --- 버튼 연결 --- generate_btn.click( fn=style_transfer, inputs=[input_image, style_dropdown, prompt_suffix, num_steps, guidance, seed], outputs=output_image, ) gr.Markdown(""" --- **Created with ❤️ by GiniGEN (2025)** """) if __name__ == "__main__": demo.launch(inline=False)