ginipick committed
Commit 741d18d · verified · 1 Parent(s): 70d01f4

Create app.py

Files changed (1): app.py +863 -0
app.py ADDED
@@ -0,0 +1,863 @@
import ast  # added for safe parsing of user-supplied bounding boxes (see process_bbox)
import gc
import os
import tempfile
import time
import uuid
from collections.abc import Sequence
from typing import Any, cast

import gradio as gr
import numpy as np
import pillow_heif
import spaces
import torch
from diffusers import FluxPipeline
from gradio_client import Client, handle_file
from gradio_image_annotation import image_annotator
from gradio_imageslider import ImageSlider
from huggingface_hub import hf_hub_download, login
from PIL import Image, ImageDraw, ImageFont
from pymatting.foreground.estimate_foreground_ml import estimate_foreground_ml
from refiners.fluxion.utils import no_grad
from refiners.solutions import BoxSegmenter
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    GroundingDinoForObjectDetection,
    GroundingDinoProcessor,
    pipeline,
)
def clear_memory():
    """Free Python and CUDA memory between requests."""
    gc.collect()
    try:
        if torch.cuda.is_available():
            with torch.cuda.device(0):  # explicitly use device 0
                torch.cuda.empty_cache()
    except Exception:
        pass
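# Assumption: on a ZeroGPU Space the CUDA context may be unavailable outside
# @spaces.GPU-decorated calls, which is why the guard above swallows errors.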
# GPU setup
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  # explicitly pin cuda:0

# Wrap the CUDA configuration in try/except so CPU-only startup still works
if torch.cuda.is_available():
    try:
        with torch.cuda.device(0):
            torch.cuda.empty_cache()
            torch.backends.cudnn.benchmark = True
            torch.backends.cuda.matmul.allow_tf32 = True
    except Exception:
        print("Warning: Could not configure CUDA settings")
# Initialize the Korean-to-English translation model (kept on CPU)
model_name = "Helsinki-NLP/opus-mt-ko-en"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to("cpu")
translator = pipeline("translation", model=model, tokenizer=tokenizer, device=-1)

def translate_to_english(text: str) -> str:
    """Translate Korean text to English; return the input unchanged on failure."""
    try:
        if any("\uac00" <= char <= "\ud7a3" for char in text):  # Hangul syllables ("가".."힣")
            translated = translator(text, max_length=128)[0]["translation_text"]
            print(f"Translated '{text}' to '{translated}'")
            return translated
        return text
    except Exception as e:
        print(f"Translation error: {str(e)}")
        return text
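# The range check above covers the precomposed Hangul Syllables block
# (U+AC00 through U+D7A3); text containing no Hangul is returned as-is.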
71
+ BoundingBox = tuple[int, int, int, int]
72
+
73
+ pillow_heif.register_heif_opener()
74
+ pillow_heif.register_avif_opener()
75
+
76
+ # HF ํ† ํฐ ์„ค์ •
77
+ HF_TOKEN = os.getenv("HF_TOKEN")
78
+ if HF_TOKEN is None:
79
+ raise ValueError("Please set the HF_TOKEN environment variable")
80
+
81
+ try:
82
+ login(token=HF_TOKEN)
83
+ except Exception as e:
84
+ raise ValueError(f"Failed to login to Hugging Face: {str(e)}")
85
+
# Initialize the segmentation and detection models
segmenter = BoxSegmenter(device="cpu")
segmenter.device = device
segmenter.model = segmenter.model.to(device=segmenter.device)

gd_model_path = "IDEA-Research/grounding-dino-base"
gd_processor = GroundingDinoProcessor.from_pretrained(gd_model_path)
gd_model = GroundingDinoForObjectDetection.from_pretrained(gd_model_path, torch_dtype=torch.float32)
gd_model = gd_model.to(device=device)
assert isinstance(gd_model, GroundingDinoForObjectDetection)
# Initialize the FLUX pipeline
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.float16,
    token=HF_TOKEN,  # `token` is the current name of the deprecated `use_auth_token`
)
pipe.enable_attention_slicing(slice_size="auto")

# Load the Hyper-SD LoRA weights
pipe.load_lora_weights(
    hf_hub_download(
        "ByteDance/Hyper-SD",
        "Hyper-FLUX.1-dev-8steps-lora.safetensors",
        token=HF_TOKEN,
    )
)
pipe.fuse_lora(lora_scale=0.125)

# Wrap the device move in try/except so CPU-only startup still works
try:
    if torch.cuda.is_available():
        pipe = pipe.to("cuda:0")  # explicitly pin cuda:0
except Exception as e:
    print(f"Warning: Could not move pipeline to CUDA: {str(e)}")

client = Client("NabeelShar/BiRefNet_for_text_writing")
class timer:
    """Context manager that prints how long the wrapped block took."""

    def __init__(self, method_name="timed process"):
        self.method = method_name

    def __enter__(self):
        self.start = time.time()
        print(f"{self.method} starts")

    def __exit__(self, exc_type, exc_val, exc_tb):
        end = time.time()
        print(f"{self.method} took {round(end - self.start, 2)}s")
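# Usage (illustrative):
#     with timer("FLUX inference"):
#         ...  # prints "FLUX inference starts", then "FLUX inference took 1.23s"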
def bbox_union(bboxes: Sequence[list[int]]) -> BoundingBox | None:
    """Smallest box that encloses every box in `bboxes` (or None if empty)."""
    if not bboxes:
        return None
    for bbox in bboxes:
        assert len(bbox) == 4
        assert all(isinstance(x, int) for x in bbox)
    return (
        min(bbox[0] for bbox in bboxes),
        min(bbox[1] for bbox in bboxes),
        max(bbox[2] for bbox in bboxes),
        max(bbox[3] for bbox in bboxes),
    )
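# Example (illustrative): boxes [0, 0, 10, 10] and [5, 5, 20, 15] merge into
# the enclosing box: bbox_union([[0, 0, 10, 10], [5, 5, 20, 15]]) == (0, 0, 20, 15).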
def corners_to_pixels_format(bboxes: torch.Tensor, width: int, height: int) -> torch.Tensor:
    """Round float corner coordinates to ints and clamp them to the image bounds."""
    x1, y1, x2, y2 = bboxes.round().to(torch.int32).unbind(-1)
    return torch.stack((x1.clamp_(0, width), y1.clamp_(0, height), x2.clamp_(0, width), y2.clamp_(0, height)), dim=-1)
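# E.g. corners (1.4, 2.6, 700.2, 500.9) on a 640x480 image become (1, 3, 640, 480):
# rounded to integers, then clamped so the box cannot extend past the frame.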
def gd_detect(img: Image.Image, prompt: str) -> BoundingBox | None:
    """Run Grounding DINO on `img` and return the union of all detected boxes."""
    inputs = gd_processor(images=img, text=f"{prompt}.", return_tensors="pt").to(device=device)
    with no_grad():
        outputs = gd_model(**inputs)
    width, height = img.size
    results: dict[str, Any] = gd_processor.post_process_grounded_object_detection(
        outputs,
        inputs["input_ids"],
        target_sizes=[(height, width)],
    )[0]
    assert "boxes" in results and isinstance(results["boxes"], torch.Tensor)
    bboxes = corners_to_pixels_format(results["boxes"].cpu(), width, height)
    return bbox_union(bboxes.numpy().tolist())
def apply_mask(img: Image.Image, mask_img: Image.Image, defringe: bool = True) -> Image.Image:
    """Cut the masked region out of `img` as an RGBA image, optionally defringing edges."""
    assert img.size == mask_img.size
    img = img.convert("RGB")
    mask_img = mask_img.convert("L")
    if defringe:
        # Estimate true foreground colors to remove color fringing around the matte
        rgb, alpha = np.asarray(img) / 255.0, np.asarray(mask_img) / 255.0
        foreground = cast(np.ndarray[Any, np.dtype[np.uint8]], estimate_foreground_ml(rgb, alpha))
        img = Image.fromarray((foreground * 255).astype("uint8"))
    result = Image.new("RGBA", img.size)
    result.paste(img, (0, 0), mask_img)
    return result
def adjust_size_to_multiple_of_8(width: int, height: int) -> tuple[int, int]:
    """Round each dimension up to the next multiple of 8."""
    new_width = ((width + 7) // 8) * 8
    new_height = ((height + 7) // 8) * 8
    return new_width, new_height

def calculate_dimensions(aspect_ratio: str, base_size: int = 512) -> tuple[int, int]:
    """Compute the image size for the selected aspect ratio."""
    if aspect_ratio == "1:1":
        return base_size, base_size
    elif aspect_ratio == "16:9":
        return base_size * 16 // 9, base_size
    elif aspect_ratio == "9:16":
        return base_size, base_size * 16 // 9
    elif aspect_ratio == "4:3":
        return base_size * 4 // 3, base_size
    return base_size, base_size
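# Worked example (illustrative): "16:9" at the default base of 512 gives
# 910x512, which adjust_size_to_multiple_of_8 rounds up to 912x512; latent
# diffusion backbones generally expect sizes divisible by 8.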
@spaces.GPU(duration=20)  # reduced from 40 seconds to 20
def generate_background(prompt: str, aspect_ratio: str) -> Image.Image:
    """Generate a background image with FLUX; fall back to plain white on error."""
    try:
        width, height = calculate_dimensions(aspect_ratio)
        width, height = adjust_size_to_multiple_of_8(width, height)

        max_size = 768
        if width > max_size or height > max_size:
            ratio = max_size / max(width, height)
            width = int(width * ratio)
            height = int(height * ratio)
            width, height = adjust_size_to_multiple_of_8(width, height)

        with timer("Background generation"):
            try:
                with torch.inference_mode():
                    image = pipe(
                        prompt=prompt,
                        width=width,
                        height=height,
                        num_inference_steps=8,
                        guidance_scale=4.0,
                    ).images[0]
            except Exception as e:
                print(f"Pipeline error: {str(e)}")
                return Image.new('RGB', (width, height), 'white')

        return image
    except Exception as e:
        print(f"Background generation error: {str(e)}")
        return Image.new('RGB', (512, 512), 'white')
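# Note: num_inference_steps=8 matches the fused Hyper-FLUX "8steps" LoRA loaded
# at startup; a different step count would not match what that LoRA was distilled for.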
def create_position_grid():
    return """
    <div class="position-grid" style="display: grid; grid-template-columns: repeat(3, 1fr); gap: 10px; width: 150px; margin: auto;">
        <button class="position-btn" data-pos="top-left">↖</button>
        <button class="position-btn" data-pos="top-center">↑</button>
        <button class="position-btn" data-pos="top-right">↗</button>
        <button class="position-btn" data-pos="middle-left">←</button>
        <button class="position-btn" data-pos="middle-center">•</button>
        <button class="position-btn" data-pos="middle-right">→</button>
        <button class="position-btn" data-pos="bottom-left">↙</button>
        <button class="position-btn" data-pos="bottom-center" data-default="true">↓</button>
        <button class="position-btn" data-pos="bottom-right">↘</button>
    </div>
    """
def calculate_object_position(position: str, bg_size: tuple[int, int], obj_size: tuple[int, int]) -> tuple[int, int]:
    """Compute the top-left paste coordinates for the object on the background."""
    bg_width, bg_height = bg_size
    obj_width, obj_height = obj_size

    positions = {
        "top-left": (0, 0),
        "top-center": ((bg_width - obj_width) // 2, 0),
        "top-right": (bg_width - obj_width, 0),
        "middle-left": (0, (bg_height - obj_height) // 2),
        "middle-center": ((bg_width - obj_width) // 2, (bg_height - obj_height) // 2),
        "middle-right": (bg_width - obj_width, (bg_height - obj_height) // 2),
        "bottom-left": (0, bg_height - obj_height),
        "bottom-center": ((bg_width - obj_width) // 2, bg_height - obj_height),
        "bottom-right": (bg_width - obj_width, bg_height - obj_height),
    }

    return positions.get(position, positions["bottom-center"])
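# Worked example (illustrative): a 100x50 object on a 512x512 background at
# "bottom-center" is pasted at ((512 - 100) // 2, 512 - 50) == (206, 462).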
def resize_object(image: Image.Image, scale_percent: float) -> Image.Image:
    """Scale the object to the given percentage of its original size."""
    width = int(image.width * scale_percent / 100)
    height = int(image.height * scale_percent / 100)
    return image.resize((width, height), Image.Resampling.LANCZOS)

def combine_with_background(foreground: Image.Image, background: Image.Image,
                            position: str = "bottom-center", scale_percent: float = 100) -> Image.Image:
    """Composite the extracted foreground onto the background."""
    # Prepare the background image
    result = background.convert('RGBA')

    # Resize the object
    scaled_foreground = resize_object(foreground, scale_percent)

    # Compute the object position
    x, y = calculate_object_position(position, result.size, scaled_foreground.size)

    # Composite
    result.paste(scaled_foreground, (x, y), scaled_foreground)
    return result
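# Sketch of how these helpers compose (`cutout` and `bg` are hypothetical PIL
# images, not variables defined in this app):
#     framed = combine_with_background(cutout, bg, position="middle-center", scale_percent=50)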
@spaces.GPU(duration=30)  # reduced from 120 seconds to 30
def _gpu_process(img: Image.Image, prompt: str | BoundingBox | None) -> tuple[Image.Image, BoundingBox | None, list[str]]:
    time_log: list[str] = []
    try:
        if isinstance(prompt, str):
            t0 = time.time()
            bbox = gd_detect(img, prompt)
            time_log.append(f"detect: {time.time() - t0}")
            if not bbox:
                print(time_log[0])
                raise gr.Error("No object detected")
        else:
            bbox = prompt
        t0 = time.time()
        mask = segmenter(img, bbox)
        time_log.append(f"segment: {time.time() - t0}")
        return mask, bbox, time_log
    except Exception as e:
        print(f"GPU process error: {str(e)}")
        raise
def _process(img: Image.Image, prompt: str | BoundingBox | None, bg_prompt: str | None = None, aspect_ratio: str = "1:1") -> tuple[tuple[Image.Image, Image.Image, Image.Image], gr.DownloadButton]:
    try:
        # Limit the input image size
        max_size = 1024
        if img.width > max_size or img.height > max_size:
            ratio = max_size / max(img.width, img.height)
            new_size = (int(img.width * ratio), int(img.height * ratio))
            img = img.resize(new_size, Image.LANCZOS)

        # CUDA memory management
        try:
            if torch.cuda.is_available():
                current_device = torch.cuda.current_device()
                with torch.cuda.device(current_device):
                    torch.cuda.empty_cache()
        except Exception as e:
            print(f"CUDA memory management failed: {e}")

        with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):
            mask, bbox, time_log = _gpu_process(img, prompt)
            masked_alpha = apply_mask(img, mask, defringe=True)

        if bg_prompt:
            background = generate_background(bg_prompt, aspect_ratio)
            combined = background
        else:
            combined = Image.alpha_composite(Image.new("RGBA", masked_alpha.size, "white"), masked_alpha)

        clear_memory()

        with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp:
            combined.save(temp.name)
        return (img, combined, masked_alpha), gr.DownloadButton(value=temp.name, interactive=True)
    except Exception as e:
        clear_memory()
        print(f"Processing error: {str(e)}")
        raise gr.Error(f"Processing failed: {str(e)}")
def on_change_bbox(prompts: dict[str, Any] | None):
    return gr.update(interactive=prompts is not None)

def on_change_prompt(img: Image.Image | None, prompt: str | None, bg_prompt: str | None = None):
    return gr.update(interactive=bool(img and prompt))
def process_prompt(img: Image.Image, prompt: str, bg_prompt: str | None = None,
                   aspect_ratio: str = "1:1", position: str = "bottom-center",
                   scale_percent: float = 100) -> tuple[Image.Image, Image.Image]:
    try:
        if img is None or prompt.strip() == "":
            raise gr.Error("Please provide both image and prompt")

        print(f"Processing with position: {position}, scale: {scale_percent}")

        try:
            prompt = translate_to_english(prompt)
            if bg_prompt:
                bg_prompt = translate_to_english(bg_prompt)
        except Exception as e:
            print(f"Translation error (continuing with original text): {str(e)}")

        results, _ = _process(img, prompt, bg_prompt, aspect_ratio)

        if bg_prompt:
            try:
                combined = combine_with_background(
                    foreground=results[2],
                    background=results[1],
                    position=position,
                    scale_percent=scale_percent,
                )
                print(f"Combined image created with position: {position}")
                return combined, results[2]
            except Exception as e:
                print(f"Combination error: {str(e)}")
                return results[1], results[2]

        return results[1], results[2]
    except Exception as e:
        print(f"Error in process_prompt: {str(e)}")
        raise gr.Error(str(e))
    finally:
        clear_memory()
def process_bbox(img: Image.Image, box_input: str) -> tuple[Image.Image, Image.Image]:
    try:
        if img is None or box_input.strip() == "":
            raise gr.Error("Please provide both image and bounding box coordinates")

        try:
            # Parse literals only; safer than eval() on user input
            coords = ast.literal_eval(box_input)
            if not isinstance(coords, list) or len(coords) != 4:
                raise ValueError("Invalid box format")
            bbox = tuple(int(x) for x in coords)
        except Exception:
            raise gr.Error("Invalid box format. Please provide [xmin, ymin, xmax, ymax]")

        # Process the image
        results, _ = _process(img, bbox)

        # Return only the combined and extracted images
        return results[1], results[2]
    except Exception as e:
        raise gr.Error(str(e))
# Event handler functions
def update_process_button(img, prompt):
    return gr.update(
        interactive=bool(img and prompt),
        variant="primary" if bool(img and prompt) else "secondary"
    )

def update_box_button(img, box_input):
    try:
        if img and box_input:
            coords = ast.literal_eval(box_input)  # parse literals only, not arbitrary code
            if isinstance(coords, list) and len(coords) == 4:
                return gr.update(interactive=True, variant="primary")
        return gr.update(interactive=False, variant="secondary")
    except Exception:
        return gr.update(interactive=False, variant="secondary")
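# Note: ast.literal_eval (used above in place of the original eval) parses only
# Python literals such as "[10, 20, 110, 120]", so arbitrary expressions typed
# into the textbox are rejected rather than executed.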
# CSS definitions
css = """
footer {display: none}
.main-title {
    text-align: center;
    margin: 2em 0;
    padding: 1em;
    background: #f7f7f7;
    border-radius: 10px;
}
.main-title h1 {
    color: #2196F3;
    font-size: 2.5em;
    margin-bottom: 0.5em;
}
.main-title p {
    color: #666;
    font-size: 1.2em;
}
.container {
    max-width: 1200px;
    margin: auto;
    padding: 20px;
}
.tabs {
    margin-top: 1em;
}
.input-group {
    background: white;
    padding: 1em;
    border-radius: 8px;
    box-shadow: 0 2px 4px rgba(0,0,0,0.1);
}
.output-group {
    background: white;
    padding: 1em;
    border-radius: 8px;
    box-shadow: 0 2px 4px rgba(0,0,0,0.1);
}
button.primary {
    background: #2196F3;
    border: none;
    color: white;
    padding: 0.5em 1em;
    border-radius: 4px;
    cursor: pointer;
    transition: background 0.3s ease;
}
button.primary:hover {
    background: #1976D2;
}
.position-btn {
    transition: all 0.3s ease;
}
.position-btn:hover {
    background-color: #e3f2fd;
}
.position-btn.selected {
    background-color: #2196F3;
    color: white;
}
"""
def add_text_with_stroke(draw, text, x, y, font, text_color, stroke_width):
    """Helper function to draw text with a stroke-like thickness."""
    # Stamp the text at every offset within +/- stroke_width to thicken it
    for adj_x in range(-stroke_width, stroke_width + 1):
        for adj_y in range(-stroke_width, stroke_width + 1):
            draw.text((x + adj_x, y + adj_y), text, font=font, fill=text_color)

def remove_background(image):
    # Save the image under a universally unique (UUID) filename
    filename = f"image_{uuid.uuid4()}.png"
    image.save(filename)
    # Call the remote BiRefNet Space for background removal
    result = client.predict(images=handle_file(filename), api_name="/image")
    return Image.open(result[0])

def superimpose(image_with_text, overlay_image):
    # Open image as RGBA to handle transparency
    overlay_image = overlay_image.convert("RGBA")
    # Paste the overlay on the background
    image_with_text.paste(overlay_image, (0, 0), overlay_image)
    # image_with_text.save("output_image.png")
    return image_with_text
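# Aside: Pillow's ImageDraw.text also accepts stroke_width= and stroke_fill=
# arguments, which would give a true outline more cheaply than the
# offset-stamping loop in add_text_with_stroke.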
def add_text_to_image(
    input_image,
    text,
    font_size,
    color,
    opacity,
    x_position,
    y_position,
    thickness,
    text_position_type,
    font_choice,  # font selection parameter
):
    """Add text to an image with customizable properties."""
    try:
        if input_image is None:
            return None

        # Convert to a PIL Image object
        if not isinstance(input_image, Image.Image):
            if isinstance(input_image, np.ndarray):
                image = Image.fromarray(input_image)
            else:
                raise ValueError("Unsupported image type")
        else:
            image = input_image.copy()

        # Convert the image to RGBA mode
        if image.mode != 'RGBA':
            image = image.convert('RGBA')

        # For "Text Behind Image", cut out the subject so it can be re-overlaid on top
        if text_position_type == "Text Behind Image":
            overlay_image = remove_background(image)

        # Create the text overlay (needed by both position types)
        txt_overlay = Image.new('RGBA', image.size, (255, 255, 255, 0))
        draw = ImageDraw.Draw(txt_overlay)

        # Font setup
        font_files = {
            "Default": "DejaVuSans.ttf",
            "Korean Regular": "ko-Regular.ttf",
            "Korean Son": "ko-son.ttf"
        }

        try:
            font_file = font_files.get(font_choice, "DejaVuSans.ttf")
            font = ImageFont.truetype(font_file, int(font_size))
        except Exception as e:
            print(f"Font loading error ({font_choice}): {str(e)}")
            try:
                font = ImageFont.truetype("arial.ttf", int(font_size))
            except Exception:
                print("Using default font")
                font = ImageFont.load_default()

        # Color setup
        color_map = {
            'White': (255, 255, 255),
            'Black': (0, 0, 0),
            'Red': (255, 0, 0),
            'Green': (0, 255, 0),
            'Blue': (0, 0, 255),
            'Yellow': (255, 255, 0),
            'Purple': (128, 0, 128)
        }
        rgb_color = color_map.get(color, (255, 255, 255))

        # Measure the text
        text_bbox = draw.textbbox((0, 0), text, font=font)
        text_width = text_bbox[2] - text_bbox[0]
        text_height = text_bbox[3] - text_bbox[1]

        # Position as a percentage of the free space
        actual_x = int((image.width - text_width) * (x_position / 100))
        actual_y = int((image.height - text_height) * (y_position / 100))

        # Text color with opacity
        text_color = (*rgb_color, int(opacity))

        # Draw the text
        add_text_with_stroke(
            draw,
            text,
            actual_x,
            actual_y,
            font,
            text_color,
            int(thickness)
        )

        if text_position_type == "Text Behind Image":
            # Draw the text first, then overlay the cutout subject on top of it
            output_image = Image.alpha_composite(image, txt_overlay)
            output_image = superimpose(output_image, overlay_image)
        else:
            # Draw the text over the image as before
            output_image = Image.alpha_composite(image, txt_overlay)

        # Convert to RGB
        output_image = output_image.convert('RGB')

        return output_image

    except Exception as e:
        print(f"Error in add_text_to_image: {str(e)}")
        return input_image
with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
    gr.HTML("""
    <div class="main-title">
        <h1>🎨 GiniGen Canvas-o3</h1>
        <p>Remove background of specified objects, generate new backgrounds, and insert text over or behind images with prompts.</p>
    </div>
    """)

    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(
                type="pil",
                label="Upload Image",
                interactive=True
            )
            text_prompt = gr.Textbox(
                label="Object to Extract",
                placeholder="Enter what you want to extract...",
                interactive=True
            )
            with gr.Row():
                bg_prompt = gr.Textbox(
                    label="Background Prompt (optional)",
                    placeholder="Describe the background...",
                    interactive=True,
                    scale=3
                )
                aspect_ratio = gr.Dropdown(
                    choices=["1:1", "16:9", "9:16", "4:3"],
                    value="1:1",
                    label="Aspect Ratio",
                    interactive=True,
                    visible=True,
                    scale=1
                )

            with gr.Row(visible=False) as object_controls:
                with gr.Column(scale=1):
                    with gr.Row():
                        position = gr.State(value="bottom-center")
                        btn_top_left = gr.Button("↖")
                        btn_top_center = gr.Button("↑")
                        btn_top_right = gr.Button("↗")
                    with gr.Row():
                        btn_middle_left = gr.Button("←")
                        btn_middle_center = gr.Button("•")
                        btn_middle_right = gr.Button("→")
                    with gr.Row():
                        btn_bottom_left = gr.Button("↙")
                        btn_bottom_center = gr.Button("↓")
                        btn_bottom_right = gr.Button("↘")
                with gr.Column(scale=1):
                    scale_slider = gr.Slider(
                        minimum=10,
                        maximum=200,
                        value=50,
                        step=5,
                        label="Object Size (%)"
                    )

            process_btn = gr.Button(
                "Process",
                variant="primary",
                interactive=False
            )
        with gr.Column(scale=1):
            with gr.Tab("Result"):
                combined_image = gr.Image(
                    label="Combined Result",
                    show_download_button=True,
                    type="pil",
                    height=512
                )

                # Group the text-insertion controls for clarity
                with gr.Group():
                    gr.Markdown("### Add Text to Image")
                    with gr.Row():
                        text_input = gr.Textbox(
                            label="Text Content",
                            placeholder="Enter text to add to image..."
                        )
                        text_position_type = gr.Radio(
                            choices=["Text Over Image", "Text Behind Image"],
                            value="Text Over Image",
                            label="Text Position Type",
                            interactive=True
                        )

                    with gr.Row():
                        with gr.Column(scale=1):
                            # Font selection dropdown
                            font_choice = gr.Dropdown(
                                choices=["Default", "Korean Regular", "Korean Son"],
                                value="Default",
                                label="Font Selection",
                                interactive=True
                            )
                            font_size = gr.Slider(
                                minimum=10,
                                maximum=200,
                                value=40,
                                step=5,
                                label="Font Size"
                            )
                            color_dropdown = gr.Dropdown(
                                choices=["White", "Black", "Red", "Green", "Blue", "Yellow", "Purple"],
                                value="White",
                                label="Text Color"
                            )
                            thickness = gr.Slider(
                                minimum=0,
                                maximum=10,
                                value=1,
                                step=1,
                                label="Text Thickness"
                            )
                        with gr.Column(scale=1):
                            opacity_slider = gr.Slider(
                                minimum=0,
                                maximum=255,
                                value=255,
                                step=1,
                                label="Opacity"
                            )
                            x_position = gr.Slider(
                                minimum=0,
                                maximum=100,
                                value=50,
                                step=1,
                                label="X Position (%)"
                            )
                            y_position = gr.Slider(
                                minimum=0,
                                maximum=100,
                                value=50,
                                step=1,
                                label="Y Position (%)"
                            )
                    add_text_btn = gr.Button("Apply Text", variant="primary")

            with gr.Row():
                extracted_image = gr.Image(
                    label="Extracted Object",
                    show_download_button=True,
                    type="pil",
                    height=256
                )
    # Click handlers for each position button
    def update_position(new_position):
        return new_position

    btn_top_left.click(fn=lambda: update_position("top-left"), outputs=position)
    btn_top_center.click(fn=lambda: update_position("top-center"), outputs=position)
    btn_top_right.click(fn=lambda: update_position("top-right"), outputs=position)
    btn_middle_left.click(fn=lambda: update_position("middle-left"), outputs=position)
    btn_middle_center.click(fn=lambda: update_position("middle-center"), outputs=position)
    btn_middle_right.click(fn=lambda: update_position("middle-right"), outputs=position)
    btn_bottom_left.click(fn=lambda: update_position("bottom-left"), outputs=position)
    btn_bottom_center.click(fn=lambda: update_position("bottom-center"), outputs=position)
    btn_bottom_right.click(fn=lambda: update_position("bottom-right"), outputs=position)

    # Event bindings
    input_image.change(
        fn=update_process_button,
        inputs=[input_image, text_prompt],
        outputs=process_btn,
        queue=False
    )

    text_prompt.change(
        fn=update_process_button,
        inputs=[input_image, text_prompt],
        outputs=process_btn,
        queue=False
    )
    def update_controls(bg_prompt):
        """Show or hide the background controls depending on whether a prompt was entered."""
        is_visible = bool(bg_prompt)
        return [
            gr.update(visible=is_visible),  # aspect_ratio
            gr.update(visible=is_visible),  # object_controls
        ]

    bg_prompt.change(
        fn=update_controls,
        inputs=bg_prompt,
        outputs=[aspect_ratio, object_controls],
        queue=False
    )

    process_btn.click(
        fn=process_prompt,
        inputs=[
            input_image,
            text_prompt,
            bg_prompt,
            aspect_ratio,
            position,
            scale_slider
        ],
        outputs=[combined_image, extracted_image],
        queue=True
    )
    # Wire the Apply Text button to add_text_to_image
    add_text_btn.click(
        fn=add_text_to_image,
        inputs=[
            combined_image,
            text_input,
            font_size,
            color_dropdown,
            opacity_slider,
            x_position,
            y_position,
            thickness,
            text_position_type,
            font_choice  # font selection input
        ],
        outputs=combined_image
    )

demo.queue(max_size=5)  # limit queue size
demo.launch(
    server_name="0.0.0.0",
    server_port=7860,
    share=False,
    max_threads=2  # limit thread count
)