webtoon / app.py
aiqtech's picture
Update app.py
71036e1 verified
raw
history blame
43.9 kB
import tempfile
import time
from collections.abc import Sequence
from typing import Any, cast
import os
from huggingface_hub import login, hf_hub_download
import gradio as gr
import numpy as np
import pillow_heif
import spaces
import torch
from gradio_image_annotation import image_annotator
from gradio_imageslider import ImageSlider
from PIL import Image
from pymatting.foreground.estimate_foreground_ml import estimate_foreground_ml
from refiners.fluxion.utils import no_grad
from refiners.solutions import BoxSegmenter
from transformers import GroundingDinoForObjectDetection, GroundingDinoProcessor
from diffusers import FluxPipeline
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
import gc
from PIL import Image, ImageDraw, ImageFont
from PIL import Image
from gradio_client import Client, handle_file
import uuid
import gradio as gr
import spaces
import torch
from diffusers import AutoencoderKL, TCDScheduler
from diffusers.models.model_loading_utils import load_state_dict
from gradio_imageslider import ImageSlider
from huggingface_hub import hf_hub_download
from transformers import pipeline
from controlnet_union import ControlNetModel_Union
from pipeline_fill_sd_xl import StableDiffusionXLFillPipeline
def debug_event(event_name, *args):
"""์ด๋ฒคํŠธ ๋””๋ฒ„๊น… ์œ ํ‹ธ๋ฆฌํ‹ฐ"""
print(f"Event '{event_name}' triggered at {time.strftime('%H:%M:%S')}")
print(f"Arguments: {args}")
return args
def clear_memory():
"""๋ฉ”๋ชจ๋ฆฌ ์ •๋ฆฌ ํ•จ์ˆ˜"""
gc.collect()
try:
if torch.cuda.is_available():
with torch.cuda.device(0): # ๋ช…์‹œ์ ์œผ๋กœ device 0 ์‚ฌ์šฉ
torch.cuda.empty_cache()
except:
pass
# GPU ์„ค์ •
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") # ๋ช…์‹œ์ ์œผ๋กœ cuda:0 ์ง€์ •
# GPU ์„ค์ •์„ try-except๋กœ ๊ฐ์‹ธ๊ธฐ
if torch.cuda.is_available():
try:
with torch.cuda.device(0):
torch.cuda.empty_cache()
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
except:
print("Warning: Could not configure CUDA settings")
# ๋ฒˆ์—ญ ๋ชจ๋ธ ์ดˆ๊ธฐํ™”
model_name = "Helsinki-NLP/opus-mt-ko-en"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to('cpu')
translator = pipeline("translation", model=model, tokenizer=tokenizer, device=-1)
def translate_to_english(text: str) -> str:
"""ํ•œ๊ธ€ ํ…์ŠคํŠธ๋ฅผ ์˜์–ด๋กœ ๋ฒˆ์—ญ"""
try:
if any(ord('๊ฐ€') <= ord(char) <= ord('ํžฃ') for char in text):
translated = translator(text, max_length=128)[0]['translation_text']
print(f"Translated '{text}' to '{translated}'")
return translated
return text
except Exception as e:
print(f"Translation error: {str(e)}")
return text
BoundingBox = tuple[int, int, int, int]
pillow_heif.register_heif_opener()
pillow_heif.register_avif_opener()
# HF ํ† ํฐ ์„ค์ •
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN is None:
raise ValueError("Please set the HF_TOKEN environment variable")
try:
login(token=HF_TOKEN)
except Exception as e:
raise ValueError(f"Failed to login to Hugging Face: {str(e)}")
# ๋ชจ๋ธ ์ดˆ๊ธฐํ™”
segmenter = BoxSegmenter(device="cpu")
segmenter.device = device
segmenter.model = segmenter.model.to(device=segmenter.device)
gd_model_path = "IDEA-Research/grounding-dino-base"
gd_processor = GroundingDinoProcessor.from_pretrained(gd_model_path)
gd_model = GroundingDinoForObjectDetection.from_pretrained(gd_model_path, torch_dtype=torch.float32)
gd_model = gd_model.to(device=device)
assert isinstance(gd_model, GroundingDinoForObjectDetection)
# FLUX ํŒŒ์ดํ”„๋ผ์ธ ์ดˆ๊ธฐํ™”
pipe = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
torch_dtype=torch.float16,
use_auth_token=HF_TOKEN
)
pipe.enable_attention_slicing(slice_size="auto")
# LoRA ๊ฐ€์ค‘์น˜ ๋กœ๋“œ
pipe.load_lora_weights(
hf_hub_download(
"ByteDance/Hyper-SD",
"Hyper-FLUX.1-dev-8steps-lora.safetensors",
use_auth_token=HF_TOKEN
)
)
pipe.fuse_lora(lora_scale=0.125)
# GPU ์„ค์ •์„ try-except๋กœ ๊ฐ์‹ธ๊ธฐ
try:
if torch.cuda.is_available():
pipe = pipe.to("cuda:0") # ๋ช…์‹œ์ ์œผ๋กœ cuda:0 ์ง€์ •
except Exception as e:
print(f"Warning: Could not move pipeline to CUDA: {str(e)}")
#------------------------------- ์ด๋ฏธ์ง€ ์ธํŽ˜์ธํŒ… ----------------------
client = Client("NabeelShar/BiRefNet_for_text_writing")
MODELS = {
"RealVisXL V5.0 Lightning": "SG161222/RealVisXL_V5.0_Lightning",
}
config_file = hf_hub_download(
"xinsir/controlnet-union-sdxl-1.0",
filename="config_promax.json",
)
config = ControlNetModel_Union.load_config(config_file)
controlnet_model = ControlNetModel_Union.from_config(config)
model_file = hf_hub_download(
"xinsir/controlnet-union-sdxl-1.0",
filename="diffusion_pytorch_model_promax.safetensors",
)
state_dict = load_state_dict(model_file)
model, _, _, _, _ = ControlNetModel_Union._load_pretrained_model(
controlnet_model, state_dict, model_file, "xinsir/controlnet-union-sdxl-1.0"
)
model.to(device="cuda", dtype=torch.float16)
vae = AutoencoderKL.from_pretrained(
"madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
).to("cuda")
pipe = StableDiffusionXLFillPipeline.from_pretrained(
"SG161222/RealVisXL_V5.0_Lightning",
torch_dtype=torch.float16,
vae=vae,
controlnet=model,
variant="fp16",
).to("cuda")
pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
def translate_if_korean(text):
# ์ž…๋ ฅ๋œ ํ…์ŠคํŠธ๊ฐ€ ํ•œ๊ธ€์„ ํฌํ•จํ•˜๊ณ  ์žˆ๋Š”์ง€ ํ™•์ธ
if any('\u3131' <= char <= '\u318E' or '\uAC00' <= char <= '\uD7A3' for char in text):
# ํ•œ๊ธ€์ด ํฌํ•จ๋˜์–ด ์žˆ๋‹ค๋ฉด ๋ฒˆ์—ญ
translated = translator(text)[0]['translation_text']
print(f"Translated prompt: {translated}") # ๋””๋ฒ„๊น…์„ ์œ„ํ•œ ์ถœ๋ ฅ
return translated
return text
@spaces.GPU
def fill_image(prompt, image, model_selection):
# ํ”„๋กฌํ”„ํŠธ ๋ฒˆ์—ญ
translated_prompt = translate_if_korean(prompt)
(
prompt_embeds,
negative_prompt_embeds,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
) = pipe.encode_prompt(translated_prompt, "cuda", True)
source = image["background"]
mask = image["layers"][0]
alpha_channel = mask.split()[3]
binary_mask = alpha_channel.point(lambda p: p > 0 and 255)
cnet_image = source.copy()
cnet_image.paste(0, (0, 0), binary_mask)
for image in pipe(
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
image=cnet_image,
):
yield image, cnet_image
image = image.convert("RGBA")
cnet_image.paste(image, (0, 0), binary_mask)
yield source, cnet_image
def clear_result():
return gr.update(value=None)
def process_inpainting(image, mask_input, prompt):
"""์ด๋ฏธ์ง€ ์ธํŽ˜์ธํŒ… ์ฒ˜๋ฆฌ ํ•จ์ˆ˜"""
try:
if image is None or mask_input is None or not prompt:
raise gr.Error("Please provide image, mask, and prompt")
# ํ”„๋กฌํ”„ํŠธ ๋ฒˆ์—ญ (ํ•œ๊ธ€์ธ ๊ฒฝ์šฐ)
translated_prompt = translate_if_korean(prompt)
# ๋งˆ์Šคํฌ ์ฒ˜๋ฆฌ
source = image
if isinstance(mask_input, dict):
mask = mask_input["layers"][0]
alpha_channel = mask.split()[3]
binary_mask = alpha_channel.point(lambda p: p > 0 and 255)
else:
raise gr.Error("Invalid mask input")
# ์ธํŽ˜์ธํŒ…์„ ์œ„ํ•œ ์ด๋ฏธ์ง€ ์ค€๋น„
cnet_image = source.copy()
cnet_image.paste(0, (0, 0), binary_mask)
# ํ”„๋กฌํ”„ํŠธ ์ž„๋ฒ ๋”ฉ
(
prompt_embeds,
negative_prompt_embeds,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
) = pipe.encode_prompt(translated_prompt, "cuda", True)
# ์ธํŽ˜์ธํŒ… ์‹คํ–‰
result = None
for image in pipe(
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
image=cnet_image,
):
result = image
if result is None:
raise gr.Error("Inpainting failed")
# ๊ฒฐ๊ณผ ์ด๋ฏธ์ง€ ์ฒ˜๋ฆฌ
result = result.convert("RGBA")
cnet_image.paste(result, (0, 0), binary_mask)
return cnet_image
except Exception as e:
print(f"Inpainting error: {str(e)}")
raise gr.Error(f"Inpainting failed: {str(e)}")
finally:
clear_memory()
#--------------- ์ด๋ฏธ์ง€ ์ธํŽ˜์ธํŒ… ๋ ----------------
class timer:
def __init__(self, method_name="timed process"):
self.method = method_name
def __enter__(self):
self.start = time.time()
print(f"{self.method} starts")
def __exit__(self, exc_type, exc_val, exc_tb):
end = time.time()
print(f"{self.method} took {str(round(end - self.start, 2))}s")
def bbox_union(bboxes: Sequence[list[int]]) -> BoundingBox | None:
if not bboxes:
return None
for bbox in bboxes:
assert len(bbox) == 4
assert all(isinstance(x, int) for x in bbox)
return (
min(bbox[0] for bbox in bboxes),
min(bbox[1] for bbox in bboxes),
max(bbox[2] for bbox in bboxes),
max(bbox[3] for bbox in bboxes),
)
def corners_to_pixels_format(bboxes: torch.Tensor, width: int, height: int) -> torch.Tensor:
x1, y1, x2, y2 = bboxes.round().to(torch.int32).unbind(-1)
return torch.stack((x1.clamp_(0, width), y1.clamp_(0, height), x2.clamp_(0, width), y2.clamp_(0, height)), dim=-1)
def gd_detect(img: Image.Image, prompt: str) -> BoundingBox | None:
inputs = gd_processor(images=img, text=f"{prompt}.", return_tensors="pt").to(device=device)
with no_grad():
outputs = gd_model(**inputs)
width, height = img.size
results: dict[str, Any] = gd_processor.post_process_grounded_object_detection(
outputs,
inputs["input_ids"],
target_sizes=[(height, width)],
)[0]
assert "boxes" in results and isinstance(results["boxes"], torch.Tensor)
bboxes = corners_to_pixels_format(results["boxes"].cpu(), width, height)
return bbox_union(bboxes.numpy().tolist())
def apply_mask(img: Image.Image, mask_img: Image.Image, defringe: bool = True) -> Image.Image:
assert img.size == mask_img.size
img = img.convert("RGB")
mask_img = mask_img.convert("L")
if defringe:
rgb, alpha = np.asarray(img) / 255.0, np.asarray(mask_img) / 255.0
foreground = cast(np.ndarray[Any, np.dtype[np.uint8]], estimate_foreground_ml(rgb, alpha))
img = Image.fromarray((foreground * 255).astype("uint8"))
result = Image.new("RGBA", img.size)
result.paste(img, (0, 0), mask_img)
return result
def adjust_size_to_multiple_of_8(width: int, height: int) -> tuple[int, int]:
"""์ด๋ฏธ์ง€ ํฌ๊ธฐ๋ฅผ 8์˜ ๋ฐฐ์ˆ˜๋กœ ์กฐ์ •ํ•˜๋Š” ํ•จ์ˆ˜"""
new_width = ((width + 7) // 8) * 8
new_height = ((height + 7) // 8) * 8
return new_width, new_height
def calculate_dimensions(aspect_ratio: str, base_size: int = 512) -> tuple[int, int]:
"""์„ ํƒ๋œ ๋น„์œจ์— ๋”ฐ๋ผ ์ด๋ฏธ์ง€ ํฌ๊ธฐ ๊ณ„์‚ฐ"""
if aspect_ratio == "1:1":
return base_size, base_size
elif aspect_ratio == "16:9":
return base_size * 16 // 9, base_size
elif aspect_ratio == "9:16":
return base_size, base_size * 16 // 9
elif aspect_ratio == "4:3":
return base_size * 4 // 3, base_size
return base_size, base_size
@spaces.GPU(duration=20) # 40์ดˆ์—์„œ 20์ดˆ๋กœ ๊ฐ์†Œ
def generate_background(prompt: str, aspect_ratio: str) -> Image.Image:
try:
width, height = calculate_dimensions(aspect_ratio)
width, height = adjust_size_to_multiple_of_8(width, height)
max_size = 768
if width > max_size or height > max_size:
ratio = max_size / max(width, height)
width = int(width * ratio)
height = int(height * ratio)
width, height = adjust_size_to_multiple_of_8(width, height)
with timer("Background generation"):
try:
with torch.inference_mode():
image = pipe(
prompt=prompt,
width=width,
height=height,
num_inference_steps=8,
guidance_scale=4.0
).images[0]
except Exception as e:
print(f"Pipeline error: {str(e)}")
return Image.new('RGB', (width, height), 'white')
return image
except Exception as e:
print(f"Background generation error: {str(e)}")
return Image.new('RGB', (512, 512), 'white')
def create_position_grid():
return """
<div class="position-grid" style="display: grid; grid-template-columns: repeat(3, 1fr); gap: 10px; width: 150px; margin: auto;">
<button class="position-btn" data-pos="top-left">โ†–</button>
<button class="position-btn" data-pos="top-center">โ†‘</button>
<button class="position-btn" data-pos="top-right">โ†—</button>
<button class="position-btn" data-pos="middle-left">โ†</button>
<button class="position-btn" data-pos="middle-center">โ€ข</button>
<button class="position-btn" data-pos="middle-right">โ†’</button>
<button class="position-btn" data-pos="bottom-left">โ†™</button>
<button class="position-btn" data-pos="bottom-center" data-default="true">โ†“</button>
<button class="position-btn" data-pos="bottom-right">โ†˜</button>
</div>
"""
def calculate_object_position(position: str, bg_size: tuple[int, int], obj_size: tuple[int, int]) -> tuple[int, int]:
"""์˜ค๋ธŒ์ ํŠธ์˜ ์œ„์น˜ ๊ณ„์‚ฐ"""
bg_width, bg_height = bg_size
obj_width, obj_height = obj_size
positions = {
"top-left": (0, 0),
"top-center": ((bg_width - obj_width) // 2, 0),
"top-right": (bg_width - obj_width, 0),
"middle-left": (0, (bg_height - obj_height) // 2),
"middle-center": ((bg_width - obj_width) // 2, (bg_height - obj_height) // 2),
"middle-right": (bg_width - obj_width, (bg_height - obj_height) // 2),
"bottom-left": (0, bg_height - obj_height),
"bottom-center": ((bg_width - obj_width) // 2, bg_height - obj_height),
"bottom-right": (bg_width - obj_width, bg_height - obj_height)
}
return positions.get(position, positions["bottom-center"])
def resize_object(image: Image.Image, scale_percent: float) -> Image.Image:
"""์˜ค๋ธŒ์ ํŠธ ํฌ๊ธฐ ์กฐ์ •"""
width = int(image.width * scale_percent / 100)
height = int(image.height * scale_percent / 100)
return image.resize((width, height), Image.Resampling.LANCZOS)
def combine_with_background(foreground: Image.Image, background: Image.Image,
position: str = "bottom-center", scale_percent: float = 100) -> Image.Image:
"""์ „๊ฒฝ๊ณผ ๋ฐฐ๊ฒฝ ํ•ฉ์„ฑ ํ•จ์ˆ˜"""
print(f"Combining with position: {position}, scale: {scale_percent}")
result = background.convert('RGBA')
scaled_foreground = resize_object(foreground, scale_percent)
x, y = calculate_object_position(position, result.size, scaled_foreground.size)
print(f"Calculated position coordinates: ({x}, {y})")
result.paste(scaled_foreground, (x, y), scaled_foreground)
return result
@spaces.GPU(duration=30) # 120์ดˆ์—์„œ 30์ดˆ๋กœ ๊ฐ์†Œ
def _gpu_process(img: Image.Image, prompt: str | BoundingBox | None) -> tuple[Image.Image, BoundingBox | None, list[str]]:
time_log: list[str] = []
try:
if isinstance(prompt, str):
t0 = time.time()
bbox = gd_detect(img, prompt)
time_log.append(f"detect: {time.time() - t0}")
if not bbox:
print(time_log[0])
raise gr.Error("No object detected")
else:
bbox = prompt
t0 = time.time()
mask = segmenter(img, bbox)
time_log.append(f"segment: {time.time() - t0}")
return mask, bbox, time_log
except Exception as e:
print(f"GPU process error: {str(e)}")
raise
def _process(img: Image.Image, prompt: str | BoundingBox | None, bg_prompt: str | None = None, aspect_ratio: str = "1:1") -> tuple[tuple[Image.Image, Image.Image, Image.Image], gr.DownloadButton]:
try:
# ์ž…๋ ฅ ์ด๋ฏธ์ง€ ํฌ๊ธฐ ์ œํ•œ
max_size = 1024
if img.width > max_size or img.height > max_size:
ratio = max_size / max(img.width, img.height)
new_size = (int(img.width * ratio), int(img.height * ratio))
img = img.resize(new_size, Image.LANCZOS)
# CUDA ๋ฉ”๋ชจ๋ฆฌ ๊ด€๋ฆฌ ์ˆ˜์ •
try:
if torch.cuda.is_available():
current_device = torch.cuda.current_device()
with torch.cuda.device(current_device):
torch.cuda.empty_cache()
except Exception as e:
print(f"CUDA memory management failed: {e}")
with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):
mask, bbox, time_log = _gpu_process(img, prompt)
masked_alpha = apply_mask(img, mask, defringe=True)
if bg_prompt:
background = generate_background(bg_prompt, aspect_ratio)
combined = background
else:
combined = Image.alpha_composite(Image.new("RGBA", masked_alpha.size, "white"), masked_alpha)
clear_memory()
with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp:
combined.save(temp.name)
return (img, combined, masked_alpha), gr.DownloadButton(value=temp.name, interactive=True)
except Exception as e:
clear_memory()
print(f"Processing error: {str(e)}")
raise gr.Error(f"Processing failed: {str(e)}")
def on_change_bbox(prompts: dict[str, Any] | None):
return gr.update(interactive=prompts is not None)
def on_change_prompt(img: Image.Image | None, prompt: str | None, bg_prompt: str | None = None):
return gr.update(interactive=bool(img and prompt))
def process_prompt(img: Image.Image, prompt: str, bg_prompt: str | None = None,
aspect_ratio: str = "1:1", position: str = "bottom-center",
scale_percent: float = 100) -> tuple[Image.Image, Image.Image]:
try:
if img is None or prompt.strip() == "":
raise gr.Error("Please provide both image and prompt")
print(f"Processing with position: {position}, scale: {scale_percent}") # ๋””๋ฒ„๊น…์šฉ
try:
prompt = translate_to_english(prompt)
if bg_prompt:
bg_prompt = translate_to_english(bg_prompt)
except Exception as e:
print(f"Translation error (continuing with original text): {str(e)}")
results, _ = _process(img, prompt, bg_prompt, aspect_ratio)
if bg_prompt:
try:
print(f"Using position: {position}") # ๋””๋ฒ„๊น…์šฉ
# ์œ„์น˜ ๊ฐ’ ๊ฒ€์ฆ
valid_positions = ["top-left", "top-center", "top-right",
"middle-left", "middle-center", "middle-right",
"bottom-left", "bottom-center", "bottom-right"]
if position not in valid_positions:
position = "bottom-center"
print(f"Invalid position, using default: {position}")
combined = combine_with_background(
foreground=results[2],
background=results[1],
position=position,
scale_percent=scale_percent
)
return combined, results[2]
except Exception as e:
print(f"Combination error: {str(e)}")
return results[1], results[2]
return results[1], results[2] # ๊ธฐ๋ณธ ๋ฐ˜ํ™˜ ์ถ”๊ฐ€
except Exception as e:
print(f"Error in process_prompt: {str(e)}")
raise gr.Error(str(e))
finally:
clear_memory()
def process_bbox(img: Image.Image, box_input: str) -> tuple[Image.Image, Image.Image]:
try:
if img is None or box_input.strip() == "":
raise gr.Error("Please provide both image and bounding box coordinates")
try:
coords = eval(box_input)
if not isinstance(coords, list) or len(coords) != 4:
raise ValueError("Invalid box format")
bbox = tuple(int(x) for x in coords)
except:
raise gr.Error("Invalid box format. Please provide [xmin, ymin, xmax, ymax]")
# Process the image
results, _ = _process(img, bbox)
# ํ•ฉ์„ฑ๋œ ์ด๋ฏธ์ง€์™€ ์ถ”์ถœ๋œ ์ด๋ฏธ์ง€๋งŒ ๋ฐ˜ํ™˜
return results[1], results[2]
except Exception as e:
raise gr.Error(str(e))
# Event handler functions ์ˆ˜์ •
def update_process_button(img, prompt):
return gr.update(
interactive=bool(img and prompt),
variant="primary" if bool(img and prompt) else "secondary"
)
def update_box_button(img, box_input):
try:
if img and box_input:
coords = eval(box_input)
if isinstance(coords, list) and len(coords) == 4:
return gr.update(interactive=True, variant="primary")
return gr.update(interactive=False, variant="secondary")
except:
return gr.update(interactive=False, variant="secondary")
css = """
/* ๊ธฐ๋ณธ ๋ ˆ์ด์•„์›ƒ */
footer {display: none !important}
body {background: #f5f7fa !important}
/* ๋ฉ”์ธ ํƒ€์ดํ‹€ */
.main-title {
text-align: center;
margin: 1.5em auto;
padding: 2em;
background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%);
border-radius: 15px;
box-shadow: 0 8px 16px rgba(0,0,0,0.1);
max-width: 1200px;
}
.main-title h1 {
color: #2196F3;
font-size: 3em;
margin-bottom: 0.5em;
font-weight: 700;
text-shadow: 2px 2px 4px rgba(0,0,0,0.1);
}
.main-title p {
color: #555;
font-size: 1.4em;
line-height: 1.6;
max-width: 800px;
margin: 0 auto;
}
/* ํŒจ๋„ ์Šคํƒ€์ผ๋ง */
.input-panel, .output-panel {
background: white;
padding: 2em;
border-radius: 15px;
box-shadow: 0 4px 12px rgba(0,0,0,0.05);
margin-bottom: 1.5em;
transition: all 0.3s ease;
}
.input-panel:hover, .output-panel:hover {
box-shadow: 0 6px 16px rgba(0,0,0,0.1);
}
/* ์ปจํŠธ๋กค ํŒจ๋„ */
.controls-panel {
background: #f8f9fa;
padding: 1.5em;
border-radius: 12px;
margin: 1.5em 0;
border: 1px solid #e9ecef;
}
/* ์ด๋ฏธ์ง€ ๋””์Šคํ”Œ๋ ˆ์ด */
.image-display {
min-height: 512px;
display: flex;
align-items: center;
justify-content: center;
background: #fafafa;
border-radius: 12px;
margin: 1.5em 0;
border: 2px dashed #e0e0e0;
}
/* ๋ฒ„ํŠผ ์Šคํƒ€์ผ๋ง */
.position-btn {
padding: 12px;
border: 2px solid #ddd;
border-radius: 8px;
background: white;
cursor: pointer;
transition: all 0.2s ease;
width: 48px;
height: 48px;
display: flex;
align-items: center;
justify-content: center;
font-size: 1.2em;
margin: 4px;
}
.position-btn:hover {
background: #e3f2fd;
transform: translateY(-2px);
box-shadow: 0 4px 8px rgba(0,0,0,0.1);
}
.position-btn.selected {
background-color: #2196F3;
color: white;
border-color: #1976D2;
box-shadow: 0 4px 12px rgba(33,150,243,0.3);
}
/* ๊ทธ๋ฆฌ๋“œ ๋ ˆ์ด์•„์›ƒ */
.position-grid {
display: grid;
grid-template-columns: repeat(3, 1fr);
gap: 10px;
margin: 1.5em 0;
padding: 10px;
background: #f5f5f5;
border-radius: 12px;
}
/* ์ž…๋ ฅ ํ•„๋“œ ์Šคํƒ€์ผ๋ง */
input[type="text"], textarea {
border: 2px solid #e0e0e0;
border-radius: 8px;
padding: 12px;
font-size: 1.1em;
transition: all 0.3s ease;
}
input[type="text"]:focus, textarea:focus {
border-color: #2196F3;
box-shadow: 0 0 0 3px rgba(33,150,243,0.2);
}
/* ์Šฌ๋ผ์ด๋” ์Šคํƒ€์ผ๋ง */
.slider-container {
margin: 1.5em 0;
}
.slider {
height: 6px;
background: #e0e0e0;
border-radius: 3px;
}
.slider-handle {
width: 20px;
height: 20px;
background: #2196F3;
border: 2px solid white;
box-shadow: 0 2px 4px rgba(0,0,0,0.2);
}
/* ์ƒํƒœ ๋ฉ”์‹œ์ง€ */
.status-message {
padding: 10px;
border-radius: 8px;
margin: 10px 0;
font-size: 0.9em;
transition: all 0.3s ease;
}
.status-success {
background: #e8f5e9;
color: #2e7d32;
border: 1px solid #a5d6a7;
}
.status-error {
background: #ffebee;
color: #c62828;
border: 1px solid #ef9a9a;
}
/* ๋ฐ˜์‘ํ˜• ๋””์ž์ธ */
@media (max-width: 768px) {
.main-title h1 {
font-size: 2em;
}
.main-title p {
font-size: 1.1em;
}
.input-panel, .output-panel {
padding: 1em;
}
.position-btn {
width: 40px;
height: 40px;
font-size: 1em;
}
}
/* ์• ๋‹ˆ๋ฉ”์ด์…˜ ํšจ๊ณผ */
@keyframes fadeIn {
from {opacity: 0; transform: translateY(10px);}
to {opacity: 1; transform: translateY(0);}
}
.fade-in {
animation: fadeIn 0.3s ease-out;
}
"""
def add_text_with_stroke(draw, text, x, y, font, text_color, stroke_width):
"""Helper function to draw text with stroke"""
# Draw the stroke/outline
for adj_x in range(-stroke_width, stroke_width + 1):
for adj_y in range(-stroke_width, stroke_width + 1):
draw.text((x + adj_x, y + adj_y), text, font=font, fill=text_color)
def remove_background(image):
# Save the image to a specific location
filename = f"image_{uuid.uuid4()}.png" # Generates a universally unique identifier (UUID) for the filename
image.save(filename)
# Call gradio client for background removal
result = client.predict(images=handle_file(filename), api_name="/image")
return Image.open(result[0])
def superimpose(image_with_text, overlay_image):
# Open image as RGBA to handle transparency
overlay_image = overlay_image.convert("RGBA")
# Paste overlay on the background
image_with_text.paste(overlay_image, (0, 0), overlay_image)
# Save the final image
# image_with_text.save("output_image.png")
return image_with_text
def add_text_to_image(
input_image,
text,
font_size,
color,
opacity,
x_position,
y_position,
thickness,
text_position_type,
font_choice
):
try:
if input_image is None or text.strip() == "":
return input_image
# PIL Image ๊ฐ์ฒด๋กœ ๋ณ€ํ™˜
if not isinstance(input_image, Image.Image):
if isinstance(input_image, np.ndarray):
image = Image.fromarray(input_image)
else:
raise ValueError("Unsupported image type")
else:
image = input_image.copy()
# ์ด๋ฏธ์ง€๋ฅผ RGBA ๋ชจ๋“œ๋กœ ๋ณ€ํ™˜
if image.mode != 'RGBA':
image = image.convert('RGBA')
# ํฐํŠธ ์„ค์ •
font_files = {
"Default": "DejaVuSans.ttf",
"Korean Regular": "ko-Regular.ttf"
}
try:
font_file = font_files.get(font_choice, "DejaVuSans.ttf")
font = ImageFont.truetype(font_file, int(font_size))
except Exception as e:
print(f"Font loading error ({font_choice}): {str(e)}")
font = ImageFont.load_default()
# ์ƒ‰์ƒ ์„ค์ •
color_map = {
'White': (255, 255, 255),
'Black': (0, 0, 0),
'Red': (255, 0, 0),
'Green': (0, 255, 0),
'Blue': (0, 0, 255),
'Yellow': (255, 255, 0),
'Purple': (128, 0, 128)
}
rgb_color = color_map.get(color, (255, 255, 255))
# ์ž„์‹œ Draw ๊ฐ์ฒด ์ƒ์„ฑํ•˜์—ฌ ํ…์ŠคํŠธ ํฌ๊ธฐ ๊ณ„์‚ฐ
temp_draw = ImageDraw.Draw(image)
text_bbox = temp_draw.textbbox((0, 0), text, font=font)
text_width = text_bbox[2] - text_bbox[0]
text_height = text_bbox[3] - text_bbox[1]
# ์œ„์น˜ ๊ณ„์‚ฐ
actual_x = int((image.width - text_width) * (x_position / 100))
actual_y = int((image.height - text_height) * (y_position / 100))
# ํ…์ŠคํŠธ ์ƒ‰์ƒ ์„ค์ •
text_color = (*rgb_color, int(opacity))
if text_position_type == "Text Behind Image":
try:
# ์›๋ณธ ์ด๋ฏธ์ง€์—์„œ ์ „๊ฒฝ ๊ฐ์ฒด๋งŒ ์ถ”์ถœ
foreground = remove_background(image)
# ๋ฐฐ๊ฒฝ ์ด๋ฏธ์ง€ ์ƒ์„ฑ (์›๋ณธ ์ด๋ฏธ์ง€ ๋ณต์‚ฌ)
background = image.copy()
# ํ…์ŠคํŠธ๋ฅผ ๊ทธ๋ฆด ์ž„์‹œ ๋ ˆ์ด์–ด ์ƒ์„ฑ
text_layer = Image.new('RGBA', image.size, (255, 255, 255, 0))
draw_text = ImageDraw.Draw(text_layer)
# ํ…์ŠคํŠธ ๊ทธ๋ฆฌ๊ธฐ
add_text_with_stroke(
draw_text,
text,
actual_x,
actual_y,
font,
text_color,
int(thickness)
)
# ๋ฐฐ๊ฒฝ์— ํ…์ŠคํŠธ ํ•ฉ์„ฑ
background = Image.alpha_composite(background, text_layer)
# ํ…์ŠคํŠธ๊ฐ€ ์žˆ๋Š” ๋ฐฐ๊ฒฝ ์œ„์— ์ „๊ฒฝ ๊ฐ์ฒด ํ•ฉ์„ฑ
output_image = Image.alpha_composite(background, foreground)
except Exception as e:
print(f"Error in Text Behind Image processing: {str(e)}")
return input_image
else:
# ํ…์ŠคํŠธ ์˜ค๋ฒ„๋ ˆ์ด ์ƒ์„ฑ
txt_overlay = Image.new('RGBA', image.size, (255, 255, 255, 0))
draw = ImageDraw.Draw(txt_overlay)
# ํ…์ŠคํŠธ๋ฅผ ์ด๋ฏธ์ง€ ์œ„์— ๊ทธ๋ฆฌ๊ธฐ
add_text_with_stroke(
draw,
text,
actual_x,
actual_y,
font,
text_color,
int(thickness)
)
output_image = Image.alpha_composite(image, txt_overlay)
# RGB๋กœ ๋ณ€ํ™˜
output_image = output_image.convert('RGB')
return output_image
except Exception as e:
print(f"Error in add_text_to_image: {str(e)}")
return input_image
def update_position(new_position):
"""์œ„์น˜ ์—…๋ฐ์ดํŠธ ํ•จ์ˆ˜"""
print(f"Position updated to: {new_position}")
return new_position
def update_position_and_ui(pos):
"""์œ„์น˜ ์—…๋ฐ์ดํŠธ ๋ฐ UI ๋ฐ˜์˜"""
updates = {btn: gr.update(value="selected" if pos_val == pos else "")
for btn, pos_val in position_mapping.items()}
updates['position'] = pos
return [pos] + [updates[btn] for btn in position_mapping.keys()]
def process_inpainting_with_feedback(image, mask, prompt):
"""์ธํŽ˜์ธํŒ… ์ฒ˜๋ฆฌ ๋ฐ ํ”ผ๋“œ๋ฐฑ"""
try:
result = process_inpainting(image, mask, prompt)
return result, update_ui_state("inpainting", "Inpainting completed successfully!")
except Exception as e:
return None, update_ui_state("inpainting", f"Error: {str(e)}", is_error=True)
def update_controls(bg_prompt):
"""๋ฐฐ๊ฒฝ ํ”„๋กฌํ”„ํŠธ ์ž…๋ ฅ ์—ฌ๋ถ€์— ๋”ฐ๋ผ ์ปจํŠธ๋กค ํ‘œ์‹œ ์—…๋ฐ์ดํŠธ"""
is_visible = bool(bg_prompt)
return [
gr.update(visible=is_visible), # aspect_ratio
gr.update(visible=is_visible), # object_controls
]
with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
position = gr.State(value="bottom-center")
processing_status = gr.State(value="idle")
gr.HTML("""
<div class="main-title">
<h1>๐ŸŽจ GiniGen Canvas-o3</h1>
<p>Remove background of specified objects, generate new backgrounds, and insert text over or behind images with prompts.</p>
</div>
""")
status_message = gr.HTML(
value='<div class="status-message"></div>',
visible=False
)
with gr.Row(equal_height=True):
# ์™ผ์ชฝ ํŒจ๋„ (์ž…๋ ฅ)
with gr.Column(scale=1):
with gr.Group(elem_classes="input-panel"):
with gr.Tabs():
# ์ฒซ ๋ฒˆ์งธ ํƒญ: ์ด๋ฏธ์ง€ ์—…๋กœ๋“œ ๋ฐ ์ธํŽ˜์ธํŒ…
with gr.Tab("Image Upload & Inpainting"):
input_image = gr.Image(
type="pil",
label="Upload Image",
interactive=True,
height=400,
elem_classes="fade-in"
)
with gr.Group():
inpaint_prompt = gr.Textbox(
label="Inpainting Prompt",
placeholder="Describe what you want to add in the masked area..."
)
mask_input = image_annotator(
label="Draw mask for inpainting",
height=400
)
inpaint_btn = gr.Button("Apply Inpainting", variant="primary")
# ๋‘ ๋ฒˆ์งธ ํƒญ: ๋ฐฐ๊ฒฝ ์ œ๊ฑฐ ๋ฐ ์ƒ์„ฑ
with gr.Tab("Background Removal"):
text_prompt = gr.Textbox(
label="Object to Extract",
placeholder="Enter what you want to extract...",
interactive=True,
elem_classes="fade-in"
)
with gr.Row():
bg_prompt = gr.Textbox(
label="Background Prompt (optional)",
placeholder="Describe the background...",
interactive=True,
scale=3
)
aspect_ratio = gr.Dropdown(
choices=["1:1", "16:9", "9:16", "4:3"],
value="1:1",
label="Aspect Ratio",
interactive=True,
visible=True,
scale=1
)
with gr.Group(elem_classes="controls-panel", visible=False) as object_controls:
with gr.Column(scale=1):
with gr.Row():
btn_top_left = gr.Button("โ†–", elem_classes="position-btn")
btn_top_center = gr.Button("โ†‘", elem_classes="position-btn")
btn_top_right = gr.Button("โ†—", elem_classes="position-btn")
with gr.Row():
btn_middle_left = gr.Button("โ†", elem_classes="position-btn")
btn_middle_center = gr.Button("โ€ข", elem_classes="position-btn")
btn_middle_right = gr.Button("โ†’", elem_classes="position-btn")
with gr.Row():
btn_bottom_left = gr.Button("โ†™", elem_classes="position-btn")
btn_bottom_center = gr.Button("โ†“", elem_classes="position-btn", value="selected")
btn_bottom_right = gr.Button("โ†˜", elem_classes="position-btn")
with gr.Column(scale=1):
scale_slider = gr.Slider(
minimum=10,
maximum=200,
value=50,
step=5,
label="Object Size (%)"
)
process_btn = gr.Button(
"Process",
variant="primary",
interactive=False,
size="lg"
)
# ์˜ค๋ฅธ์ชฝ ํŒจ๋„ (์ถœ๋ ฅ)
with gr.Column(scale=1):
with gr.Group(elem_classes="output-panel"):
with gr.Tab("Result"):
combined_image = gr.Image(
label="Combined Result",
show_download_button=True,
type="pil",
height=400
)
with gr.Accordion("Text Insertion Options", open=False):
with gr.Group():
with gr.Row():
text_input = gr.Textbox(
label="Text Content",
placeholder="Enter text to add..."
)
text_position_type = gr.Radio(
choices=["Text Over Image", "Text Behind Image"],
value="Text Over Image",
label="Text Position"
)
with gr.Row():
with gr.Column(scale=1):
font_choice = gr.Dropdown(
choices=["Default", "Korean Regular"],
value="Default",
label="Font Selection",
interactive=True
)
font_size = gr.Slider(
minimum=10,
maximum=200,
value=40,
step=5,
label="Font Size"
)
color_dropdown = gr.Dropdown(
choices=["White", "Black", "Red", "Green", "Blue", "Yellow", "Purple"],
value="White",
label="Text Color"
)
thickness = gr.Slider(
minimum=0,
maximum=10,
value=1,
step=1,
label="Text Thickness"
)
with gr.Column(scale=1):
opacity_slider = gr.Slider(
minimum=0,
maximum=255,
value=255,
step=1,
label="Opacity"
)
x_position = gr.Slider(
minimum=0,
maximum=100,
value=50,
step=1,
label="Left(0%)~Right(100%)"
)
y_position = gr.Slider(
minimum=0,
maximum=100,
value=50,
step=1,
label="High(0%)~Low(100%)"
)
add_text_btn = gr.Button("Apply Text", variant="primary")
extracted_image = gr.Image(
label="Extracted Object",
show_download_button=True,
type="pil",
height=200
)
# CSS ์Šคํƒ€์ผ
gr.HTML("""
<style>
.position-btn.selected {
background-color: #2196F3 !important;
color: white !important;
}
</style>
""")
# ์ด๋ฒคํŠธ ๋ฐ”์ธ๋”ฉ
position_mapping = {
btn_top_left: "top-left",
btn_top_center: "top-center",
btn_top_right: "top-right",
btn_middle_left: "middle-left",
btn_middle_center: "middle-center",
btn_middle_right: "middle-right",
btn_bottom_left: "bottom-left",
btn_bottom_center: "bottom-center",
btn_bottom_right: "bottom-right"
}
def update_ui_state(component_id, value, is_error=False):
"""UI ์ƒํƒœ ์—…๋ฐ์ดํŠธ ์œ ํ‹ธ๋ฆฌํ‹ฐ"""
class_name = "status-error" if is_error else "status-success"
return gr.update(
value=f'<div class="status-message {class_name}">{value}</div>',
visible=True
)
# ์œ„์น˜ ๋ฒ„ํŠผ ์ด๋ฒคํŠธ ๋ฐ”์ธ๋”ฉ
for btn, pos in position_mapping.items():
btn.click(
fn=lambda p=pos: update_position_and_ui(p),
outputs=[position] + list(position_mapping.keys())
)
# ์ธํŽ˜์ธํŒ… ๋ฒ„ํŠผ ์ด๋ฒคํŠธ
inpaint_btn.click(
fn=process_inpainting_with_feedback,
inputs=[input_image, mask_input, inpaint_prompt],
outputs=[input_image, status_message]
)
# ํ”„๋กœ์„ธ์Šค ๋ฒ„ํŠผ ์ด๋ฒคํŠธ
process_btn.click(
fn=process_prompt,
inputs=[
input_image,
text_prompt,
bg_prompt,
aspect_ratio,
position,
scale_slider
],
outputs=[combined_image, extracted_image]
)
for btn, pos in position_mapping.items():
btn.click(
fn=lambda pos=pos: update_position(pos),
outputs=position
)
bg_prompt.change(
fn=update_controls,
inputs=bg_prompt,
outputs=[aspect_ratio, object_controls],
queue=False
)
input_image.change(
fn=update_process_button,
inputs=[input_image, text_prompt],
outputs=process_btn,
queue=False
)
text_prompt.change(
fn=update_process_button,
inputs=[input_image, text_prompt],
outputs=process_btn,
queue=False
)
add_text_btn.click(
fn=add_text_to_image,
inputs=[
combined_image,
text_input,
font_size,
color_dropdown,
opacity_slider,
x_position,
y_position,
thickness,
text_position_type,
font_choice
],
outputs=combined_image,
api_name="add_text"
)
demo.queue(max_size=5)
demo.launch(
server_name="0.0.0.0",
server_port=7860,
share=False,
max_threads=2
)