ixarchakos committed
Commit cb6ddcb · verified · 1 Parent(s): 2ace5b2

Upload 11 files

app.py ADDED
@@ -0,0 +1,220 @@
+ from typing import TypedDict
+
+ import diffusers.image_processor
+ import gradio as gr
+ import pillow_heif
+ import spaces
+ import torch
+ from PIL import Image
+ from transformers import SegformerImageProcessor, AutoModelForSemanticSegmentation
+ from pipeline import TryOffAnyone
+ import numpy as np
+
+
+ pillow_heif.register_heif_opener()
+ pillow_heif.register_avif_opener()
+
+ torch.set_float32_matmul_precision("high")
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ TITLE = """
+ # Try Off Anyone
+
+ ## Important
+
+ 1. Choose an example image or upload your own
+
+ [[arxiv:2412.08573]](https://arxiv.org/abs/2412.08573)
+ [[github:ixarchakos/try-off-anyone]](https://github.com/ixarchakos/try-off-anyone)
+ """
+
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "mps")
+ # use bfloat16 only on CUDA; compare the device type rather than the device object
+ DTYPE = torch.bfloat16 if DEVICE.type == "cuda" else torch.float32
+
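+ # Global models/processors: the TryOff pipeline itself plus VaeImageProcessor instances that
+ # resize inputs to multiples of the VAE scale factor (8); the mask processor additionally
+ # binarizes the mask and converts it to grayscale.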
+ pipeline_tryoff = TryOffAnyone(
+     device=DEVICE,
+     dtype=DTYPE,
+ )
+ mask_processor = diffusers.image_processor.VaeImageProcessor(
+     vae_scale_factor=8,
+     do_normalize=False,
+     do_binarize=True,
+     do_convert_grayscale=True,
+ )
+ vae_processor = diffusers.image_processor.VaeImageProcessor(
+     vae_scale_factor=8,
+ )
+
+
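+ # Predict a binary garment mask with the SegFormer clothes-segmentation model (runs on CPU).
+ # The class ids below are assumed to follow the usual clothes-parsing label map:
+ # 4 = upper clothes, 7 = dress, 5 = skirt, 6 = pants.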
+ def mask_generation(image, processor, model, category):
+     inputs = processor(images=image, return_tensors="pt")
+     with torch.no_grad():
+         outputs = model(**inputs)
+     logits = outputs.logits.cpu()
+
+     upsampled_logits = torch.nn.functional.interpolate(
+         logits,
+         size=image.size[::-1],
+         mode="bilinear",
+         align_corners=False,
+     )
+
+     predicted_mask = upsampled_logits.argmax(dim=1).squeeze().cpu().numpy()
+     if category == "Tops":
+         predicted_mask_1 = predicted_mask == 4
+         predicted_mask_2 = predicted_mask == 7
+     elif category == "Bottoms":
+         predicted_mask_1 = predicted_mask == 5
+         predicted_mask_2 = predicted_mask == 6
+     else:
+         raise NotImplementedError
+
+     predicted_mask = predicted_mask_1 + predicted_mask_2
+     mask_image = Image.fromarray((predicted_mask * 255).astype(np.uint8))
+     return mask_image
+
+
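+ # Value produced by gr.ImageMask with type="pil": the uploaded photo ("background") plus any
+ # hand-drawn editor layers and their composite.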
+ class ImageData(TypedDict):
+     background: Image.Image
+     composite: Image.Image
+     layers: list[Image.Image]
+
+
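+ # Full request handler: predict the garment mask on CPU with SegFormer, preprocess image and
+ # mask for the VAE, then run the TryOffAnyone diffusion pipeline on the GPU granted by @spaces.GPU.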
+ @spaces.GPU
+ def process(
+     image_data: ImageData,
+     image_width: int,
+     image_height: int,
+     num_inference_steps: int,
+     condition_scale: float,
+     seed: int,
+ ) -> Image.Image:
+     assert image_width > 0
+     assert image_height > 0
+     assert num_inference_steps > 0
+     assert condition_scale > 0
+     assert seed >= 0
+
+     # only the uploaded photo is used; the garment mask is predicted below
+     image = image_data["background"]
+     processor = SegformerImageProcessor.from_pretrained("sayeed99/segformer_b3_clothes")
+     model = AutoModelForSemanticSegmentation.from_pretrained("sayeed99/segformer_b3_clothes")
+     model.to('cpu')
+
+     # preprocess image and predict the garment mask (category is fixed to "Tops")
+     image = image.convert("RGB").resize((image_width, image_height))
+     mask = mask_generation(image, processor, model, "Tops")
+     image_preprocessed = vae_processor.preprocess(
+         image=image,
+         width=image_width,
+         height=image_height,
+     )[0]
+
+     # preprocess mask
+     mask = mask.resize((image_width, image_height))
+     mask_preprocessed = mask_processor.preprocess(  # pyright: ignore[reportUnknownMemberType]
+         image=mask,
+         width=image_width,
+         height=image_height,
+     )[0]
+
+     # generate the TryOff image
+     gen = torch.Generator(device=DEVICE).manual_seed(seed)
+     tryoff_image = pipeline_tryoff(
+         image_preprocessed,
+         mask_preprocessed,
+         inference_steps=num_inference_steps,
+         scale=condition_scale,
+         generator=gen,
+     )[0]
+
+     return tryoff_image
+
+
+ with gr.Blocks() as demo:
+     gr.Markdown(TITLE)
+
+     with gr.Row():
+         with gr.Column():
+             input_image = gr.ImageMask(
+                 label="Input Image",
+                 height=1024,
+                 type="pil",
+                 interactive=True,
+             )
+             run_button = gr.Button(
+                 value="Extract Clothing",
+             )
+             gr.Examples(
+                 examples=[
+                     ["examples/model_1.jpg"],
+                     ["examples/model_2.jpg"],
+                     ["examples/model_3.jpg"],
+                     ["examples/model_4.jpg"],
+                     ["examples/model_5.jpg"],
+                     ["examples/model_6.jpg"],
+                     ["examples/model_7.jpg"],
+                     ["examples/model_8.jpg"],
+                     ["examples/model_9.jpg"],
+                 ],
+                 inputs=[input_image],
+             )
+         with gr.Column():
+             output_image = gr.Image(
+                 label="TryOff result",
+                 height=1024,
+                 image_mode="RGB",
+                 type="pil",
+             )
+
+     with gr.Accordion("Advanced Settings", open=False):
+         seed = gr.Slider(
+             label="Seed",
+             minimum=36,
+             maximum=36,
+             value=36,
+             step=1,
+         )
+         scale = gr.Slider(
+             label="Scale",
+             minimum=2.5,
+             maximum=2.5,
+             value=2.5,
+             step=0,
+         )
+         num_inference_steps = gr.Slider(
+             label="Number of inference steps",
+             minimum=50,
+             maximum=50,
+             value=50,
+             step=1,
+         )
+         with gr.Row():
+             image_width = gr.Slider(
+                 label="Image Width",
+                 minimum=384,
+                 maximum=384,
+                 value=384,
+                 step=8,
+             )
+             image_height = gr.Slider(
+                 label="Image Height",
+                 minimum=512,
+                 maximum=512,
+                 value=512,
+                 step=8,
+             )
+
+     run_button.click(
+         fn=process,
+         inputs=[
+             input_image,
+             image_width,
+             image_height,
+             num_inference_steps,
+             scale,
+             seed,
+         ],
+         outputs=output_image,
+     )
+
+ demo.launch()
examples/model_1.jpg ADDED
examples/model_2.jpg ADDED
examples/model_3.jpg ADDED
examples/model_4.jpg ADDED
examples/model_5.jpg ADDED
examples/model_6.jpg ADDED
examples/model_7.jpg ADDED
examples/model_8.jpg ADDED
examples/model_9.jpg ADDED
pipeline.py ADDED
@@ -0,0 +1,159 @@
+ # type: ignore
+ # Inspired by https://github.com/ixarchakos/try-off-anyone/blob/aa3045453013065573a647e4536922bac696b968/src/model/pipeline.py
+ # Inspired by https://github.com/ixarchakos/try-off-anyone/blob/aa3045453013065573a647e4536922bac696b968/src/model/attention.py
+
+ import torch
+ from accelerate import load_checkpoint_in_model
+ from diffusers import AutoencoderKL, DDIMScheduler, UNet2DConditionModel
+ from diffusers.models.attention_processor import AttnProcessor
+ from diffusers.utils.torch_utils import randn_tensor
+ from huggingface_hub import hf_hub_download
+ from PIL import Image
+
+
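+ # No-op attention processor used to skip the UNet's cross-attention layers; the pipeline runs
+ # without text conditioning (encoder_hidden_states=None).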
+ class Skip(torch.nn.Module):
+     def __init__(self) -> None:
+         super().__init__()
+
+     def __call__(
+         self,
+         attn: torch.Tensor,
+         hidden_states: torch.Tensor,
+         encoder_hidden_states: torch.Tensor = None,
+         attention_mask: torch.Tensor = None,
+         temb: torch.Tensor = None,
+     ) -> torch.Tensor:
+         return hidden_states
+
+
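+ # Collect the UNet attention blocks whose weights are replaced by the fine-tuned checkpoint
+ # loaded in TryOffAnyone.__init__.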
+ def fine_tuned_modules(unet: UNet2DConditionModel) -> torch.nn.ModuleList:
+     trainable_modules = torch.nn.ModuleList()
+
+     for blocks in [unet.down_blocks, unet.mid_block, unet.up_blocks]:
+         if hasattr(blocks, "attentions"):
+             trainable_modules.append(blocks.attentions)
+         else:
+             for block in blocks:
+                 if hasattr(block, "attentions"):
+                     trainable_modules.append(block.attentions)
+
+     return trainable_modules
+
+
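+ # Keep the original processors for self-attention ("attn1") layers and replace every
+ # cross-attention processor with the no-op Skip defined above.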
+ def skip_cross_attentions(unet: UNet2DConditionModel) -> dict[str, AttnProcessor | Skip]:
+     attn_processors = {
+         name: unet.attn_processors[name] if name.endswith("attn1.processor") else Skip()
+         for name in unet.attn_processors.keys()
+     }
+     return attn_processors
+
+
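+ # Encode an image tensor into VAE latents, scaled by the VAE scaling factor.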
+ def encode(image: torch.Tensor, vae: AutoencoderKL) -> torch.Tensor:
+     image = image.to(memory_format=torch.contiguous_format).float().to(vae.device, dtype=vae.dtype)
+     with torch.no_grad():
+         return vae.encode(image).latent_dist.sample() * vae.config.scaling_factor
+
+
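+ # TryOff pipeline: the Stable Diffusion 1.5 inpainting UNet with cross-attention skipped and
+ # its attention weights replaced by the fine-tuned checkpoint "ixarchakos/tryOffAnyone".
+ # Masked-person and person latents are concatenated along the height axis (concat_dim=-2);
+ # only the first half of the denoised latents is decoded and returned.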
+ class TryOffAnyone:
+     def __init__(
+         self,
+         device: torch.device,
+         dtype: torch.dtype,
+         concat_dim: int = -2,
+     ) -> None:
+         self.concat_dim = concat_dim
+         self.device = device
+         self.dtype = dtype
+
+         self.noise_scheduler = DDIMScheduler.from_pretrained(
+             pretrained_model_name_or_path="stable-diffusion-v1-5/stable-diffusion-inpainting",
+             subfolder="scheduler",
+         )
+         self.vae = AutoencoderKL.from_pretrained(
+             pretrained_model_name_or_path="stabilityai/sd-vae-ft-mse",
+         ).to(device, dtype=dtype)
+         self.unet = UNet2DConditionModel.from_pretrained(
+             pretrained_model_name_or_path="stable-diffusion-v1-5/stable-diffusion-inpainting",
+             subfolder="unet",
+             variant="fp16",
+         ).to(device, dtype=dtype)
+
+         self.unet.set_attn_processor(skip_cross_attentions(self.unet))
+         load_checkpoint_in_model(
+             model=fine_tuned_modules(unet=self.unet),
+             checkpoint=hf_hub_download(
+                 repo_id="ixarchakos/tryOffAnyone",
+                 filename="model.safetensors",
+             ),
+         )
+
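+     # Denoise from pure noise with DDIM: the masked person image and the full person image are
+     # VAE-encoded and concatenated along concat_dim as conditioning for the 9-channel inpainting
+     # UNet (4 latent + 1 mask + 4 conditioning channels). Classifier-free guidance is applied
+     # when scale > 1.0; its unconditional branch zeroes the person-image latent.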
+     @torch.no_grad()
+     def __call__(
+         self,
+         image: torch.Tensor,
+         mask: torch.Tensor,
+         inference_steps: int,
+         scale: float,
+         generator: torch.Generator,
+     ) -> list[Image.Image]:
+         image = image.unsqueeze(0).to(self.device, dtype=self.dtype)
+         mask = (mask.unsqueeze(0) > 0.5).to(self.device, dtype=self.dtype)
+         masked_image = image * (mask < 0.5)
+
+         masked_latent = encode(masked_image, self.vae)
+         image_latent = encode(image, self.vae)
+         mask = torch.nn.functional.interpolate(mask, size=masked_latent.shape[-2:], mode="nearest")
+
+         masked_latent_concat = torch.cat([masked_latent, image_latent], dim=self.concat_dim)
+         mask_concat = torch.cat([mask, torch.zeros_like(mask)], dim=self.concat_dim)
+
+         latents = randn_tensor(
+             shape=masked_latent_concat.shape,
+             generator=generator,
+             device=self.device,
+             dtype=self.dtype,
+         )
+
+         self.noise_scheduler.set_timesteps(inference_steps, device=self.device)
+         timesteps = self.noise_scheduler.timesteps
+
+         if do_classifier_free_guidance := (scale > 1.0):
+             # unconditional branch: the person-image latent is replaced with zeros
+             masked_latent_concat = torch.cat(
+                 [
+                     torch.cat([masked_latent, torch.zeros_like(image_latent)], dim=self.concat_dim),
+                     masked_latent_concat,
+                 ]
+             )
+
+             mask_concat = torch.cat([mask_concat] * 2)
+
+         extra_step = {"generator": generator, "eta": 1.0}
+         for t in timesteps:
+             input_latents = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+             input_latents = self.noise_scheduler.scale_model_input(input_latents, t)
+
+             # 9-channel UNet input: noisy latents + inpainting mask + conditioning latents
+             input_latents = torch.cat([input_latents, mask_concat, masked_latent_concat], dim=1)
+
+             noise_pred = self.unet(
+                 input_latents,
+                 t.to(self.device),
+                 encoder_hidden_states=None,
+                 return_dict=False,
+             )[0]
+
+             if do_classifier_free_guidance:
+                 noise_pred_unc, noise_pred_text = noise_pred.chunk(2)
+                 noise_pred = noise_pred_unc + scale * (noise_pred_text - noise_pred_unc)
+
+             latents = self.noise_scheduler.step(noise_pred, t, latents, **extra_step).prev_sample
+
+         # keep the first half of the concatenated latents and decode it with the VAE
+         latents = latents.split(latents.shape[self.concat_dim] // 2, dim=self.concat_dim)[0]
+         latents = 1 / self.vae.config.scaling_factor * latents
+         image = self.vae.decode(latents.to(self.device, dtype=self.dtype)).sample
+         image = (image / 2 + 0.5).clamp(0, 1)
+         image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+
+         image = (image * 255).round().astype("uint8")
+         image = [Image.fromarray(im) for im in image]
+
+         return image
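+
+
+ # Minimal usage sketch (assumes a CUDA device, downloadable checkpoints, and a pre-computed
+ # binary garment mask; `person.jpg`, the 384x512 size, and seed 36 are placeholders that
+ # mirror how app.py calls this module):
+ #
+ #   import torch
+ #   from diffusers.image_processor import VaeImageProcessor
+ #   from PIL import Image
+ #
+ #   device = torch.device("cuda")
+ #   pipe = TryOffAnyone(device=device, dtype=torch.bfloat16)
+ #   vae_proc = VaeImageProcessor(vae_scale_factor=8)
+ #   mask_proc = VaeImageProcessor(
+ #       vae_scale_factor=8, do_normalize=False, do_binarize=True, do_convert_grayscale=True,
+ #   )
+ #
+ #   image = Image.open("person.jpg").convert("RGB").resize((384, 512))
+ #   mask = ...  # PIL image of the garment mask, e.g. from the SegFormer model used in app.py
+ #   garment = pipe(
+ #       vae_proc.preprocess(image, width=384, height=512)[0],
+ #       mask_proc.preprocess(mask, width=384, height=512)[0],
+ #       inference_steps=50,
+ #       scale=2.5,
+ #       generator=torch.Generator(device=device).manual_seed(36),
+ #   )[0]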