import numpy as np
from PIL import Image
from huggingface_hub import snapshot_download, login
from leffa.transform import LeffaTransform
from leffa.model import LeffaModel
from leffa.inference import LeffaInference
from utils.garment_agnostic_mask_predictor import AutoMasker
from utils.densepose_predictor import DensePosePredictor
from utils.utils import resize_and_center
import spaces
import torch
from diffusers import DiffusionPipeline
from transformers import pipeline
import gradio as gr
import os
import random
import gc
# ์ƒ์ˆ˜ ์ •์˜
MAX_SEED = 2**32 - 1
BASE_MODEL = "black-forest-labs/FLUX.1-dev"
MODEL_LORA_REPO = "Motas/Flux_Fashion_Photography_Style"
CLOTHES_LORA_REPO = "prithivMLmods/Canopus-Clothing-Flux-LoRA"
# Decorator that clears GPU memory around model calls
def safe_model_call(func):
def wrapper(*args, **kwargs):
try:
clear_memory()
result = func(*args, **kwargs)
clear_memory()
return result
except Exception as e:
clear_memory()
print(f"Error in {func.__name__}: {str(e)}")
raise
return wrapper
# Memory management helper
def clear_memory():
gc.collect()
    if torch.cuda.is_available():
torch.cuda.empty_cache()
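# Usage sketch for safe_model_call (illustrative only; `load_example_model` is
# a hypothetical name, not part of this app). The decorator clears CUDA memory
# before and after the wrapped call, and again if the call raises:
#
# @safe_model_call
# def load_example_model():
#     return LeffaModel(pretrained_model_name_or_path="./ckpts/...")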
def setup_environment():
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128'
HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
raise ValueError("HF_TOKEN not found in environment variables")
login(token=HF_TOKEN)
return HF_TOKEN
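# Illustrative usage (assumes the HF_TOKEN secret is configured in the Space
# settings): setup_environment() logs in to the Hub and returns the token,
# raising ValueError when the variable is missing.
# token = setup_environment()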
@spaces.GPU()
def generate_image(prompt, mode, cfg_scale=7.0, steps=30, seed=None, width=512, height=768, lora_scale=0.85):
try:
# ํŒŒ์ดํ”„๋ผ์ธ ์ดˆ๊ธฐํ™”
pipe = DiffusionPipeline.from_pretrained(
BASE_MODEL,
torch_dtype=torch.float16,
)
pipe.to("cuda")
        # Configure LoRA weights for the selected mode
if mode == "Generate Model":
pipe.load_lora_weights(MODEL_LORA_REPO)
trigger_word = "fashion photography, professional model"
else:
pipe.load_lora_weights(CLOTHES_LORA_REPO)
trigger_word = "upper clothing, fashion item"
# ์ƒ์„ฑ ์„ค์ •
generator = torch.Generator("cuda").manual_seed(seed if seed is not None else torch.randint(0, 2**32 - 1, (1,)).item())
# ์ด๋ฏธ์ง€ ์ƒ์„ฑ
with torch.inference_mode():
result = pipe(
prompt=f"{prompt} {trigger_word}",
num_inference_steps=steps,
guidance_scale=cfg_scale,
width=width,
height=height,
generator=generator,
cross_attention_kwargs={"scale": lora_scale},
).images[0]
        # Free memory
del pipe
clear_memory()
return result, seed
except Exception as e:
clear_memory()
raise gr.Error(f"Generation failed: {str(e)}")
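# Illustrative call (hypothetical prompt; must run where a GPU is available,
# e.g. inside the @spaces.GPU context):
# image, used_seed = generate_image(
#     "minimalist white shirt", "Generate Clothes",
#     cfg_scale=7.0, steps=30, seed=42,
# )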
# Global state (models are loaded lazily)
fashion_pipe = None
translator = None
mask_predictor = None
densepose_predictor = None
vt_model = None
pt_model = None
vt_inference = None
pt_inference = None
device = "cuda" if torch.cuda.is_available() else "cpu"
HF_TOKEN = None
# Run environment setup at import time and keep the token for later use
HF_TOKEN = setup_environment()
def setup():
    # Download the Leffa checkpoints only
snapshot_download(repo_id="franciszzj/Leffa", local_dir="./ckpts")
@spaces.GPU()
def get_translator():
global translator
if translator is None:
translator = pipeline("translation",
model="Helsinki-NLP/opus-mt-ko-en",
device="cuda")
return translator
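# Illustrative usage (hypothetical input; exact output depends on the
# Helsinki-NLP/opus-mt-ko-en model):
# get_translator()("๋นจ๊ฐ„ ๋“œ๋ ˆ์Šค")[0]["translation_text"]  # e.g. "red dress"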
@safe_model_call
def get_mask_predictor():
global mask_predictor
if mask_predictor is None:
mask_predictor = AutoMasker(
densepose_path="./ckpts/densepose",
schp_path="./ckpts/schp",
)
return mask_predictor
@safe_model_call
def get_densepose_predictor():
global densepose_predictor
if densepose_predictor is None:
densepose_predictor = DensePosePredictor(
config_path="./ckpts/densepose/densepose_rcnn_R_50_FPN_s1x.yaml",
weights_path="./ckpts/densepose/model_final_162be9.pkl",
)
return densepose_predictor
@safe_model_call
def get_vt_model():
global vt_model, vt_inference
if vt_model is None:
vt_model = LeffaModel(
pretrained_model_name_or_path="./ckpts/stable-diffusion-inpainting",
pretrained_model="./ckpts/virtual_tryon.pth"
)
vt_model = vt_model.half().to(device)
vt_inference = LeffaInference(model=vt_model)
return vt_model, vt_inference
@safe_model_call
def get_pt_model():
global pt_model, pt_inference
if pt_model is None:
pt_model = LeffaModel(
pretrained_model_name_or_path="./ckpts/stable-diffusion-xl-1.0-inpainting-0.1",
pretrained_model="./ckpts/pose_transfer.pth"
)
pt_model = pt_model.half().to(device)
pt_inference = LeffaInference(model=pt_model)
return pt_model, pt_inference
def load_lora(pipe, lora_path):
try:
pipe.unload_lora_weights()
    except Exception:
pass
try:
pipe.load_lora_weights(lora_path)
return pipe
except Exception as e:
print(f"Warning: Failed to load LoRA weights from {lora_path}: {e}")
return pipe
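# Illustrative usage: swap LoRA adapters on an existing pipeline without
# rebuilding it (pipe is assumed to be a loaded DiffusionPipeline):
# pipe = load_lora(pipe, MODEL_LORA_REPO)    # fashion-model style
# pipe = load_lora(pipe, CLOTHES_LORA_REPO)  # garment style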
# Utility functions
def contains_korean(text):
return any(ord('๊ฐ€') <= ord(char) <= ord('ํžฃ') for char in text)
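# Examples (Hangul syllables occupy U+AC00..U+D7A3):
#   contains_korean("fashion model")  -> False
#   contains_korean("ํŒจ์…˜ ๋ชจ๋ธ")        -> True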
# ๋ชจ๋ธ ์ดˆ๊ธฐํ™” ํ•จ์ˆ˜ ์ˆ˜์ •
@spaces.GPU()
def initialize_fashion_pipe():
try:
pipe = DiffusionPipeline.from_pretrained(
BASE_MODEL,
torch_dtype=torch.float16,
safety_checker=None,
requires_safety_checker=False
).to("cuda")
pipe.enable_model_cpu_offload()
return pipe
except Exception as e:
print(f"Error initializing fashion pipe: {e}")
raise
# ์ƒ์„ฑ ํ•จ์ˆ˜ ์ˆ˜์ •
@spaces.GPU()
def generate_fashion(prompt, mode, cfg_scale, steps, randomize_seed, seed, width, height, lora_scale, progress=gr.Progress(track_tqdm=True)):
try:
        # Translate Korean prompts to English
if contains_korean(prompt):
with torch.inference_mode():
translator = get_translator()
translated = translator(prompt)[0]['translation_text']
actual_prompt = translated
else:
actual_prompt = prompt
# ํŒŒ์ดํ”„๋ผ์ธ ์ดˆ๊ธฐํ™”
pipe = initialize_fashion_pipe()
        # Configure LoRA weights for the selected mode
if mode == "Generate Model":
pipe.load_lora_weights(MODEL_LORA_REPO)
trigger_word = "fashion photography, professional model"
else:
pipe.load_lora_weights(CLOTHES_LORA_REPO)
trigger_word = "upper clothing, fashion item"
# ํŒŒ๋ผ๋ฏธํ„ฐ ์ œํ•œ
width = min(width, 768)
height = min(height, 768)
steps = min(steps, 30)
        # Seed handling
if randomize_seed:
seed = random.randint(0, MAX_SEED)
generator = torch.Generator("cuda").manual_seed(seed)
# ์ด๋ฏธ์ง€ ์ƒ์„ฑ
with torch.inference_mode():
output = pipe(
prompt=f"{actual_prompt} {trigger_word}",
num_inference_steps=steps,
guidance_scale=cfg_scale,
width=width,
height=height,
generator=generator,
cross_attention_kwargs={"scale": lora_scale},
)
image = output.images[0]
        # Free memory
del pipe
torch.cuda.empty_cache()
gc.collect()
return image, seed
except Exception as e:
print(f"Error in generate_fashion: {str(e)}")
raise gr.Error(f"Generation failed: {str(e)}")
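# Illustrative call mirroring the Gradio handler wired up below (hypothetical
# values; Korean prompts are translated automatically):
# image, used_seed = generate_fashion(
#     "elegant black evening dress", "Generate Clothes",
#     cfg_scale=7.0, steps=30, randomize_seed=True, seed=0,
#     width=512, height=768, lora_scale=0.85,
# )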
@safe_model_call
def leffa_predict(src_image_path, ref_image_path, control_type):
try:
# ๋ชจ๋ธ ์ดˆ๊ธฐํ™”
if control_type == "virtual_tryon":
model, inference = get_vt_model()
else:
model, inference = get_pt_model()
mask_pred = get_mask_predictor()
dense_pred = get_densepose_predictor()
        # Load and preprocess the images
src_image = Image.open(src_image_path)
ref_image = Image.open(ref_image_path)
src_image = resize_and_center(src_image, 768, 1024)
ref_image = resize_and_center(ref_image, 768, 1024)
src_image_array = np.array(src_image)
ref_image_array = np.array(ref_image)
# Mask ์ƒ์„ฑ
if control_type == "virtual_tryon":
src_image = src_image.convert("RGB")
mask = mask_pred(src_image, "upper")["mask"]
else:
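            # Pose transfer repaints the whole person, so an all-white
            # (fully editable) mask is used instead of a garment mask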
mask = Image.fromarray(np.ones_like(src_image_array) * 255)
        # DensePose prediction
src_image_iuv_array = dense_pred.predict_iuv(src_image_array)
src_image_seg_array = dense_pred.predict_seg(src_image_array)
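        # Virtual try-on conditions on the segmentation map; pose transfer uses the IUV map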
if control_type == "virtual_tryon":
densepose = Image.fromarray(src_image_seg_array)
else:
densepose = Image.fromarray(src_image_iuv_array)
        # Leffa transform and inference
transform = LeffaTransform()
data = {
"src_image": [src_image],
"ref_image": [ref_image],
"mask": [mask],
"densepose": [densepose],
}
data = transform(data)
output = inference(data)
return np.array(output["generated_image"][0])
except Exception as e:
print(f"Error in leffa_predict: {str(e)}")
raise
@safe_model_call
def leffa_predict_vt(src_image_path, ref_image_path):
return leffa_predict(src_image_path, ref_image_path, "virtual_tryon")
@safe_model_call
def leffa_predict_pt(src_image_path, ref_image_path):
return leffa_predict(src_image_path, ref_image_path, "pose_transfer")
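# Illustrative usage (hypothetical file paths): both wrappers return the
# generated image as a NumPy array.
# out = leffa_predict_vt("person.webp", "garment.webp")
# Image.fromarray(out).save("tryon_result.png")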
# Initial setup: download checkpoints
setup()
# Gradio interface
with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange") as demo:
gr.Markdown("# ๐ŸŽญ FitGen:Fashion Studio & Virtual Try-on")
with gr.Tabs():
# ํŒจ์…˜ ์ƒ์„ฑ ํƒญ
# ํŒจ์…˜ ์ƒ์„ฑ ํƒญ
with gr.Tab("Fashion Generation"):
with gr.Column():
mode = gr.Radio(
choices=["Generate Model", "Generate Clothes"],
label="Generation Mode",
value="Generate Model"
)
                # Example prompts
example_model_prompts = [
"professional fashion model, full body shot, standing pose, natural lighting, studio background, high fashion, elegant pose",
"fashion model portrait, upper body, confident pose, fashion photography, neutral background, professional lighting",
"stylish fashion model, three-quarter view, editorial pose, high-end fashion magazine style, minimal background"
]
example_clothes_prompts = [
"luxury designer sweater, cashmere material, cream color, cable knit pattern, high-end fashion, product photography",
"elegant business blazer, tailored fit, charcoal grey, premium wool fabric, professional wear",
"modern streetwear hoodie, oversized fit, minimalist design, premium cotton, urban style"
]
prompt = gr.TextArea(
label="Fashion Description (ํ•œ๊ธ€ ๋˜๋Š” ์˜์–ด)",
placeholder="ํŒจ์…˜ ๋ชจ๋ธ์ด๋‚˜ ์˜๋ฅ˜๋ฅผ ์„ค๋ช…ํ•˜์„ธ์š”..."
)
                # Example prompt section
gr.Examples(
examples=example_model_prompts + example_clothes_prompts,
inputs=prompt,
label="Example Prompts"
)
with gr.Row():
with gr.Column():
result = gr.Image(label="Generated Result")
generate_button = gr.Button("Generate Fashion")
with gr.Accordion("Advanced Options", open=False):
with gr.Group():
with gr.Row():
with gr.Column():
cfg_scale = gr.Slider(
label="CFG Scale",
minimum=1,
maximum=20,
step=0.5,
value=7.0
)
steps = gr.Slider(
label="Steps",
minimum=1,
                                    maximum=50,  # reduced maximum
step=1,
value=30
)
lora_scale = gr.Slider(
label="LoRA Scale",
minimum=0,
maximum=1,
step=0.01,
value=0.85
)
with gr.Row():
width = gr.Slider(
label="Width",
minimum=256,
                                maximum=1024,  # reduced maximum
step=64,
value=512
)
height = gr.Slider(
label="Height",
minimum=256,
                                maximum=1024,  # reduced maximum
step=64,
value=768
)
with gr.Row():
randomize_seed = gr.Checkbox(
True,
label="Randomize seed"
)
seed = gr.Slider(
label="Seed",
minimum=0,
maximum=MAX_SEED,
step=1,
value=42
)
        # Virtual try-on tab
with gr.Tab("Virtual Try-on"):
with gr.Row():
with gr.Column():
gr.Markdown("#### Person Image")
vt_src_image = gr.Image(
sources=["upload"],
type="filepath",
label="Person Image",
width=512,
height=512,
)
gr.Examples(
inputs=vt_src_image,
examples_per_page=5,
examples=["a1.webp",
"a2.webp",
"a3.webp",
"a4.webp",
"a5.webp"]
)
with gr.Column():
gr.Markdown("#### Garment Image")
vt_ref_image = gr.Image(
sources=["upload"],
type="filepath",
label="Garment Image",
width=512,
height=512,
)
gr.Examples(
inputs=vt_ref_image,
examples_per_page=5,
examples=["b1.webp",
"b2.webp",
"b3.webp",
"b4.webp",
"b5.webp"]
)
with gr.Column():
gr.Markdown("#### Generated Image")
vt_gen_image = gr.Image(
label="Generated Image",
width=512,
height=512,
)
vt_gen_button = gr.Button("Try-on")
        # Pose transfer tab
with gr.Tab("Pose Transfer"):
with gr.Row():
with gr.Column():
gr.Markdown("#### Person Image")
pt_ref_image = gr.Image(
sources=["upload"],
type="filepath",
label="Person Image",
width=512,
height=512,
)
gr.Examples(
inputs=pt_ref_image,
examples_per_page=5,
examples=["a1.webp",
"a2.webp",
"a3.webp",
"a4.webp",
"a5.webp"]
)
with gr.Column():
gr.Markdown("#### Target Pose Person Image")
pt_src_image = gr.Image(
sources=["upload"],
type="filepath",
label="Target Pose Person Image",
width=512,
height=512,
)
gr.Examples(
inputs=pt_src_image,
examples_per_page=5,
examples=["d1.webp",
"d2.webp",
"d3.webp",
"d4.webp",
"d5.webp"]
)
with gr.Column():
gr.Markdown("#### Generated Image")
pt_gen_image = gr.Image(
label="Generated Image",
width=512,
height=512,
)
pose_transfer_gen_button = gr.Button("Generate")
    # Event handlers
generate_button.click(
generate_fashion,
inputs=[prompt, mode, cfg_scale, steps, randomize_seed, seed, width, height, lora_scale],
outputs=[result, seed]
)
vt_gen_button.click(
fn=leffa_predict_vt,
inputs=[vt_src_image, vt_ref_image],
outputs=[vt_gen_image]
)
pose_transfer_gen_button.click(
fn=leffa_predict_pt,
inputs=[pt_src_image, pt_ref_image],
outputs=[pt_gen_image]
)
if __name__ == "__main__":
    # The interface is already built at module level above; just launch it
    # (environment setup and checkpoint download have run at import time)
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False
    )