ginipick committed on
Commit
38dd8f0
·
verified ·
1 Parent(s): 2447d97

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +210 -0
app.py CHANGED
@@ -1,3 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  # Revised Gradio UI section
2
  with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
3
  gr.HTML("""
 
1
+ import tempfile
2
+ import time
3
+ from collections.abc import Sequence
4
+ from typing import Any, cast
5
+ import os
6
+ from huggingface_hub import login, hf_hub_download
7
+
8
+ import gradio as gr
9
+ import numpy as np
10
+ import pillow_heif
11
+ import spaces
12
+ import torch
13
+ from gradio_image_annotation import image_annotator
14
+ from gradio_imageslider import ImageSlider
15
+ from PIL import Image
16
+ from pymatting.foreground.estimate_foreground_ml import estimate_foreground_ml
17
+ from refiners.fluxion.utils import no_grad
18
+ from refiners.solutions import BoxSegmenter
19
+ from transformers import GroundingDinoForObjectDetection, GroundingDinoProcessor
20
+ from diffusers import FluxPipeline
21
+
22
# A box expressed as (xmin, ymin, xmax, ymax) corner coordinates.
BoundingBox = tuple[int, int, int, int]

# Register the HEIF and AVIF openers shipped with pillow_heif.
pillow_heif.register_heif_opener()
pillow_heif.register_avif_opener()

# Use CUDA when torch reports it as available, otherwise fall back to CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
28
+
29
# Hugging Face token setup: read the token from the environment and log in.
HF_TOKEN = os.getenv("HF_TOKEN")
# An empty value is as unusable as a missing one, so reject both here.
if not HF_TOKEN:
    raise ValueError("Please set the HF_TOKEN environment variable")

try:
    login(token=HF_TOKEN)
except Exception as e:
    # Chain the original exception so the root cause stays in the traceback.
    raise ValueError(f"Failed to login to Hugging Face: {e}") from e
38
+
39
# Model initialization.
# The segmenter is constructed on CPU, then its device attribute and model
# are pointed at the module-level `device`.
segmenter = BoxSegmenter(device="cpu")
segmenter.device = device
segmenter.model = segmenter.model.to(device=segmenter.device)

# Grounding DINO processor and detection model, loaded in float32.
gd_model_path = "IDEA-Research/grounding-dino-base"
gd_processor = GroundingDinoProcessor.from_pretrained(gd_model_path)
gd_model = GroundingDinoForObjectDetection.from_pretrained(
    gd_model_path,
    torch_dtype=torch.float32,
).to(device=device)
assert isinstance(gd_model, GroundingDinoForObjectDetection)
49
+
50
# FLUX pipeline initialization.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
    # `token` is the current keyword; `use_auth_token` is deprecated.
    token=HF_TOKEN,
)
# Load the Hyper-SD 8-step LoRA and fuse it into the pipeline weights.
pipe.load_lora_weights(
    hf_hub_download(
        "ByteDance/Hyper-SD",
        "Hyper-FLUX.1-dev-8steps-lora.safetensors",
        token=HF_TOKEN,
    )
)
pipe.fuse_lora(lora_scale=0.125)
pipe.to(device="cuda", dtype=torch.bfloat16)
65
+
66
class timer:
    """Context manager that prints when a block starts and how long it took.

    Args:
        method_name: Label used in both printed messages.

    Attributes:
        method: The label passed as ``method_name``.
        start: ``time.perf_counter()`` reading taken on entry.
    """

    def __init__(self, method_name="timed process"):
        self.method = method_name

    def __enter__(self):
        # perf_counter is monotonic, so the measured interval cannot be
        # skewed by system clock adjustments the way time.time() can.
        self.start = time.perf_counter()
        print(f"{self.method} starts")
        # Return the instance so `with timer(...) as t:` binds something useful.
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Returns None, so an exception raised in the block still propagates.
        elapsed = time.perf_counter() - self.start
        print(f"{self.method} took {round(elapsed, 2)}s")
75
+
76
def bbox_union(bboxes: Sequence[list[int]]) -> "BoundingBox | None":
    """Return the smallest box that encloses every box in *bboxes*.

    Args:
        bboxes: Boxes given as ``[xmin, ymin, xmax, ymax]`` lists of ints.

    Returns:
        ``(min xmin, min ymin, max xmax, max ymax)``, or ``None`` when
        *bboxes* is empty.

    Raises:
        ValueError: If a box does not have exactly four entries.
        TypeError: If a box contains a non-int entry.
    """
    if not bboxes:
        return None
    # Explicit raises instead of `assert`: asserts are stripped under -O.
    for bbox in bboxes:
        if len(bbox) != 4:
            raise ValueError(f"expected 4 coordinates per box, got {len(bbox)}")
        if not all(isinstance(x, int) for x in bbox):
            raise TypeError(f"box coordinates must be ints, got {bbox!r}")
    # Transpose once into per-coordinate columns.
    xmins, ymins, xmaxs, ymaxs = zip(*bboxes)
    return (min(xmins), min(ymins), max(xmaxs), max(ymaxs))
88
+
89
def corners_to_pixels_format(bboxes: torch.Tensor, width: int, height: int) -> torch.Tensor:
    """Round corner boxes to int32 pixels and clamp them to the image bounds.

    Args:
        bboxes: Tensor whose last dimension holds ``(x1, y1, x2, y2)``.
        width: Upper clamp bound for the x coordinates.
        height: Upper clamp bound for the y coordinates.

    Returns:
        An int32 tensor with the same layout, x clamped to ``[0, width]``
        and y clamped to ``[0, height]``.
    """
    pixel_boxes = bboxes.round().to(torch.int32)
    left, top, right, bottom = pixel_boxes.unbind(-1)
    clamped = (
        left.clamp(0, width),
        top.clamp(0, height),
        right.clamp(0, width),
        bottom.clamp(0, height),
    )
    return torch.stack(clamped, dim=-1)
92
+
93
def gd_detect(img: Image.Image, prompt: str) -> BoundingBox | None:
    """Detect *prompt* in *img* with Grounding DINO and return the union box.

    Args:
        img: Image to search.
        prompt: Text query; a trailing "." is appended before processing.

    Returns:
        The union of all detected boxes in pixel coordinates, or ``None``
        when nothing is detected.
    """
    text_query = f"{prompt}."
    model_inputs = gd_processor(images=img, text=text_query, return_tensors="pt").to(device=device)
    with no_grad():
        model_outputs = gd_model(**model_inputs)

    img_w, img_h = img.size
    detections = gd_processor.post_process_grounded_object_detection(
        model_outputs,
        model_inputs["input_ids"],
        target_sizes=[(img_h, img_w)],
    )
    # Only one image was submitted, so only the first result is relevant.
    first: dict[str, Any] = detections[0]
    assert "boxes" in first and isinstance(first["boxes"], torch.Tensor)

    pixel_boxes = corners_to_pixels_format(first["boxes"].cpu(), img_w, img_h)
    return bbox_union(pixel_boxes.numpy().tolist())
106
+
107
def apply_mask(img: Image.Image, mask_img: Image.Image, defringe: bool = True) -> Image.Image:
    """Cut *img* out with *mask_img* and return the result as an RGBA image.

    Args:
        img: Source image; converted to RGB.
        mask_img: Mask of the same size; converted to "L" and used as the
            paste mask.
        defringe: When true, the RGB values are re-estimated with
            ``estimate_foreground_ml`` before pasting.

    Returns:
        A new RGBA image with the (optionally re-estimated) colours pasted
        through the mask.
    """
    assert img.size == mask_img.size
    rgb_img = img.convert("RGB")
    alpha_img = mask_img.convert("L")

    if defringe:
        # Scale both inputs to [0, 1] floats for the foreground estimator.
        rgb = np.asarray(rgb_img) / 255.0
        alpha = np.asarray(alpha_img) / 255.0
        foreground = cast(np.ndarray[Any, np.dtype[np.uint8]], estimate_foreground_ml(rgb, alpha))
        rgb_img = Image.fromarray((foreground * 255).astype("uint8"))

    cutout = Image.new("RGBA", rgb_img.size)
    cutout.paste(rgb_img, (0, 0), alpha_img)
    return cutout
118
+
119
def generate_background(
    prompt: str,
    width: int,
    height: int,
    *,
    num_inference_steps: int = 8,
    guidance_scale: float = 4.0,
) -> Image.Image:
    """Generate a background image with the FLUX pipeline.

    Args:
        prompt: Text prompt passed to the pipeline.
        width: Requested image width.
        height: Requested image height.
        num_inference_steps: Number of inference steps passed to the pipeline.
        guidance_scale: Guidance scale passed to the pipeline.

    Returns:
        The first image produced by the pipeline.

    Raises:
        gr.Error: If the pipeline call fails; the original exception is
            chained as the cause.
    """
    try:
        with timer("Background generation"):
            image = pipe(
                prompt=prompt,
                width=width,
                height=height,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
            ).images[0]
        return image
    except Exception as e:
        # Keep the root cause attached for server-side tracebacks.
        raise gr.Error(f"Background generation failed: {e}") from e
133
+
134
def combine_with_background(foreground: Image.Image, background: Image.Image) -> Image.Image:
    """Composite *foreground* over *background*.

    The background is first resized to the foreground's size and converted
    to RGBA, then the foreground is alpha-composited on top of it.
    """
    fitted = background.resize(foreground.size)
    backdrop = fitted.convert('RGBA')
    return Image.alpha_composite(backdrop, foreground)
138
+
139
@spaces.GPU
def _gpu_process(img: Image.Image, prompt: str | BoundingBox | None) -> tuple[Image.Image, BoundingBox | None, list[str]]:
    """Resolve the target box and run the segmenter on it.

    Args:
        img: Image to segment.
        prompt: A text prompt (resolved to a box with ``gd_detect``), or a
            box / ``None`` that is passed to the segmenter unchanged.

    Returns:
        ``(mask, bbox, time_log)`` where ``time_log`` holds one timing
        string per stage that ran.

    Raises:
        gr.Error: If a text prompt yields no detection.
    """
    timings: list[str] = []

    if not isinstance(prompt, str):
        bbox = prompt
    else:
        started = time.time()
        bbox = gd_detect(img, prompt)
        timings.append(f"detect: {time.time() - started}")
        if not bbox:
            # Print the detection timing before failing.
            print(timings[0])
            raise gr.Error("No object detected")

    started = time.time()
    mask = segmenter(img, bbox)
    timings.append(f"segment: {time.time() - started}")
    return mask, bbox, timings
155
+
156
def _process(img: Image.Image, prompt: str | BoundingBox | None, bg_prompt: str | None = None) -> tuple[tuple[Image.Image, Image.Image, Image.Image], gr.DownloadButton]:
    """Cut the prompted object out of *img* and optionally place it on a new background.

    Args:
        img: Input image. If either side exceeds 2048 pixels it is shrunk
            in place with ``thumbnail``.
        prompt: Text prompt, bounding box, or ``None``, forwarded to
            ``_gpu_process``. A tuple box is rescaled when the image is shrunk.
        bg_prompt: When truthy, a background is generated from this prompt;
            otherwise the cutout is composited on white.

    Returns:
        ``((img, combined, masked_alpha), download_button)`` where the
        button points at a temporary PNG of the cutout cropped to the
        thresholded mask's bounding box.

    Raises:
        gr.Error: If background generation or compositing fails.
    """
    if img.width > 2048 or img.height > 2048:
        orig_res = max(img.width, img.height)
        img.thumbnail((2048, 2048))
        if isinstance(prompt, tuple):
            # Scale the box by the same factor the longest side was scaled by.
            x0, y0, x1, y1 = (int(x * 2048 / orig_res) for x in prompt)
            prompt = (x0, y0, x1, y1)

    # Only the mask is needed here; the box and timings are discarded.
    mask, _, _ = _gpu_process(img, prompt)
    masked_alpha = apply_mask(img, mask, defringe=True)

    if bg_prompt:
        try:
            background = generate_background(bg_prompt, img.width, img.height)
            combined = combine_with_background(masked_alpha, background)
        except Exception as e:
            # Chain the cause so the underlying failure stays in the traceback.
            raise gr.Error(f"Background processing failed: {e}") from e
    else:
        combined = Image.alpha_composite(Image.new("RGBA", masked_alpha.size, "white"), masked_alpha)

    # Binarise the mask (values above 10 count as foreground) to find the crop box.
    thresholded = mask.point(lambda p: 255 if p > 10 else 0)
    crop_box = thresholded.getbbox()
    to_dl = masked_alpha.crop(crop_box)

    # delete=False keeps the file on disk for the download button; the
    # context manager closes the handle even if saving raises.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp:
        to_dl.save(temp, format="PNG")

    return (img, combined, masked_alpha), gr.DownloadButton(value=temp.name, interactive=True)
185
+
186
def process_bbox(prompts: dict[str, Any]) -> tuple[tuple[Image.Image, Image.Image, Image.Image], gr.DownloadButton]:
    """Run the pipeline from an annotator payload.

    Args:
        prompts: Mapping with an ``"image"`` entry (a PIL image) and a
            ``"boxes"`` entry (a list holding zero or one box dict with
            ``xmin``/``ymin``/``xmax``/``ymax`` keys).

    Returns:
        Whatever ``_process`` returns for the image and the selected box
        (``None`` when no box was drawn).

    Raises:
        TypeError: If the image, box list, or box has the wrong type.
        ValueError: If more than one box is supplied.
    """
    # Bind first, then validate: names bound by a walrus inside `assert`
    # are never created when Python runs with -O.
    img = prompts["image"]
    boxes = prompts["boxes"]
    if not isinstance(img, Image.Image):
        raise TypeError(f"prompts['image'] must be a PIL image, got {type(img).__name__}")
    if not isinstance(boxes, list):
        raise TypeError(f"prompts['boxes'] must be a list, got {type(boxes).__name__}")
    if len(boxes) > 1:
        raise ValueError(f"expected at most one box, got {len(boxes)}")

    bbox = None
    if boxes:
        box = boxes[0]
        if not isinstance(box, dict):
            raise TypeError(f"box must be a dict, got {type(box).__name__}")
        bbox = tuple(box[k] for k in ("xmin", "ymin", "xmax", "ymax"))
    return _process(img, bbox)
196
+
197
def on_change_bbox(prompts: dict[str, Any] | None):
    """Build a ``gr.update`` whose ``interactive`` flag is true exactly when *prompts* is not ``None``."""
    has_prompts = prompts is not None
    return gr.update(interactive=has_prompts)
199
+
200
def process_prompt(img: Image.Image, prompt: str, bg_prompt: str | None = None) -> tuple[tuple[Image.Image, Image.Image], gr.DownloadButton]:
    """Run the pipeline from a text prompt by delegating to ``_process``.

    Args:
        img: Input image.
        prompt: Text describing the object to cut out.
        bg_prompt: Optional background prompt, forwarded unchanged.
    """
    return _process(img, prompt, bg_prompt=bg_prompt)
202
+
203
def on_change_prompt(img: Image.Image | None, prompt: str | None, bg_prompt: str | None = None):
    """Build a ``gr.update`` that is interactive only when both *img* and *prompt* are truthy.

    ``bg_prompt`` is accepted but does not affect the result.
    """
    ready = bool(img and prompt)
    return gr.update(interactive=ready)
205
+
206
def update_button_state(img, prompt):
    """Build a ``gr.update`` that is interactive only when both *img* and *prompt* are truthy.

    Uses ``gr.update`` like the other change handlers in this file; the
    per-component ``gr.Button.update`` classmethod is not available in
    Gradio 4.
    """
    return gr.update(interactive=bool(img and prompt))
208
+
209
+
210
+
211
  # Revised Gradio UI section
212
  with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
213
  gr.HTML("""