rizavelioglu committed
Commit 8eb415a · Parent: 1f9630e

- bump versions
- add v2 models (see the loading sketch below)
- add more examples
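
A minimal sketch of how the v2 checkpoints added in this commit can be loaded outside the Space, mirroring `load_model()` in the new app.py; the checkpoint filenames and the `_orig_mod.` key cleanup come from this diff, while the CPU fallback and the standalone function name are assumptions:

```python
# Sketch only: follows app.py's load_model(); device fallback and function name are assumptions.
import torch
from huggingface_hub import hf_hub_download

from model import create_model  # model.py is added by this commit

device = "cuda" if torch.cuda.is_available() else "cpu"


def load_tryoffdiff_v2(class_name: str, filename: str, repo_id: str = "rizavelioglu/tryoffdiff"):
    path = hf_hub_download(repo_id=repo_id, filename=filename)
    state_dict = torch.load(path, weights_only=True, map_location=device)
    # Strip the "_orig_mod." prefix, as app.py's load_model() does.
    state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}
    model = create_model(class_name).to(device)
    model.load_state_dict(state_dict, strict=True)
    return model.eval()


# e.g. the multi-garment model introduced in this commit:
# model = load_tryoffdiff_v2("TryOffDiffv2", "tryoffdiffv2_multi.pth")
```

The usage line is left commented out because instantiating the model pulls the Stable Diffusion v1.4 UNet weights in its constructor.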

README.md CHANGED
@@ -4,7 +4,7 @@ emoji: 🔥
  colorFrom: yellow
  colorTo: yellow
  sdk: gradio
- sdk_version: 5.29.0
+ sdk_version: 5.32.1
  app_file: app.py
  pinned: true
  license: other
app.py CHANGED
@@ -1,160 +1,379 @@
  import os
  from pathlib import Path

- from diffusers import AutoencoderKL, EulerDiscreteScheduler, UNet2DConditionModel
- from esrgan_model import UpscalerESRGAN
- import gradio as gr
- from huggingface_hub import hf_hub_download
- import spaces
  import torch
- import torch.nn as nn
  from torchvision.io import read_image
  import torchvision.transforms.v2 as transforms
  from torchvision.utils import make_grid
- from transformers import SiglipImageProcessor, SiglipVisionModel
-
-
- class TryOffDiff(nn.Module):
-     def __init__(self):
-         super().__init__()
-         self.unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet")
-         self.transformer = torch.nn.TransformerEncoderLayer(d_model=768, nhead=8, batch_first=True)
-         self.proj = nn.Linear(1024, 77)
-         self.norm = nn.LayerNorm(768)

-     def adapt_embeddings(self, x):
-         x = self.transformer(x)
-         x = self.proj(x.permute(0, 2, 1)).permute(0, 2, 1)
-         return self.norm(x)

-     def forward(self, noisy_latents, t, cond_emb):
-         cond_emb = self.adapt_embeddings(cond_emb)
-         return self.unet(noisy_latents, t, encoder_hidden_states=cond_emb).sample


  class PadToSquare:
      def __call__(self, img):
-         _, h, w = img.shape  # Get the original dimensions
          max_side = max(h, w)
          pad_h = (max_side - h) // 2
          pad_w = (max_side - w) // 2
          padding = (pad_w, pad_h, max_side - w - pad_w, max_side - h - pad_h)
          return transforms.functional.pad(img, padding, padding_mode="edge")


- # Set device
- device = "cuda" if torch.cuda.is_available() else "cpu"
-
- # Initialize Image Encoder
- img_processor = SiglipImageProcessor.from_pretrained(
-     "google/siglip-base-patch16-512", do_resize=False, do_rescale=False, do_normalize=False
- )
- img_enc = SiglipVisionModel.from_pretrained("google/siglip-base-patch16-512").eval().to(device)
- img_enc_transform = transforms.Compose(
-     [
-         PadToSquare(),  # Custom transform to pad the image to a square
-         transforms.Resize((512, 512)),
-         transforms.ToDtype(torch.float32, scale=True),
-         transforms.Normalize(mean=[0.5], std=[0.5]),
-     ]
- )
-
- # Load TryOffDiff Model
- path_model = hf_hub_download(
-     repo_id="rizavelioglu/tryoffdiff",
-     filename="tryoffdiff.pth",  # or one of ["ldm-1", "ldm-2", "ldm-3", ...],
-     force_download=False,
- )
- path_scheduler = hf_hub_download(
-     repo_id="rizavelioglu/tryoffdiff", filename="scheduler/scheduler_config.json", force_download=False
- )
- net = TryOffDiff()
- net.load_state_dict(torch.load(path_model, weights_only=False))
- net.eval().to(device)
-
- # Initialize VAE (only Decoder will be used)
- vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").eval().to(device)
-
- # Initialize the upscaler
- upscaler = UpscalerESRGAN(
-     model_path=Path(
-         hf_hub_download(
-             repo_id="philz1337x/upscaler",
-             filename="4x-UltraSharp.pth",
-             # revision="011deacac8270114eb7d2eeff4fe6fa9a837be70",
-         )
-     ),
-     device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
-     dtype=torch.float32,
- )

- torch.cuda.empty_cache()


- # Define image generation function
  @spaces.GPU(duration=10)
  @torch.no_grad()
- def generate_image(
-     input_image, seed: int = 42, guidance_scale: float = 2.0, num_inference_steps: int = 50, is_upscale: bool = False
- ):
-     # Configure scheduler
-     scheduler = EulerDiscreteScheduler.from_pretrained(path_scheduler)
-     scheduler.is_scale_input_called = True  # suppress warning
      scheduler.set_timesteps(num_inference_steps)
-
-     # Set seed for reproducibility
      generator = torch.Generator(device=device).manual_seed(seed)
      x = torch.randn(1, 4, 64, 64, generator=generator, device=device)

      # Process input image
      cond_image = img_enc_transform(read_image(input_image))
-     inputs = {k: v.to(img_enc.device) for k, v in img_processor(images=cond_image, return_tensors="pt").items()}
      cond_emb = img_enc(**inputs).last_hidden_state.to(device)

-     # Prepare unconditioned embeddings (only if guidance is enabled)
      uncond_emb = torch.zeros_like(cond_emb) if guidance_scale > 1 else None

-     # Diffusion denoising loop with mixed precision for efficiency
      with torch.autocast(device):
          for t in scheduler.timesteps:
-             if guidance_scale > 1:
-                 # Classifier-Free Guidance (CFG)
-                 noise_pred = net(torch.cat([x] * 2), t, torch.cat([uncond_emb, cond_emb])).chunk(2)
                  noise_pred = noise_pred[0] + guidance_scale * (noise_pred[1] - noise_pred[0])
-             else:
-                 # Standard prediction
-                 noise_pred = net(x, t, cond_emb)

              # Scheduler step
              scheduler_output = scheduler.step(noise_pred, t, x)
              x = scheduler_output.prev_sample

      # Decode predictions from latent space
-     decoded = vae.decode(1 / 0.18215 * scheduler_output.pred_original_sample).sample
      images = (decoded / 2 + 0.5).cpu()

-     # Create grid
      grid = make_grid(images, nrow=len(images), normalize=True, scale_each=True)
      output_image = transforms.ToPILImage()(grid)

-     # Optionally upscale the output image
-     if is_upscale:
-         output_image = upscaler(output_image)

-     return output_image


- title = "Virtual Try-Off Generator"
- description = r"""
- This is the demo of the paper <a href="https://arxiv.org/abs/2411.18350">TryOffDiff: Virtual-Try-Off via High-Fidelity Garment Reconstruction using Diffusion Models</a>.
- <br>Upload an image of a clothed individual to generate a standardized garment image using TryOffDiff.
- <br> Check out the <a href="https://rizavelioglu.github.io/tryoffdiff/">project page</a> for more information.
  """
  article = r"""
- Example images are sampled from the `VITON-HD-test` set, which the models did not see during training.
-
- <br>**Citation** <br>If you find our work useful in your research, please consider giving a star ⭐ and
- a citation:
  ```
  @article{velioglu2024tryoffdiff,
    title = {TryOffDiff: Virtual-Try-Off via High-Fidelity Garment Reconstruction using Diffusion Models},
@@ -163,36 +382,82 @@ a citation:
    year = {2024},
    note = {\url{https://doi.org/nt3n}}
  }
  ```
  """
- examples = [[f"examples/{img_filename}", 42, 2.0, 20, False] for img_filename in sorted(os.listdir("examples/"))]
-
- # Create Gradio App
- demo = gr.Interface(
-     fn=generate_image,
-     inputs=[
-         gr.Image(type="filepath", label="Reference Image", height=448),
-         gr.Slider(value=42, minimum=0, maximum=1e6, step=1, label="Seed"),
-         gr.Slider(
-             value=2.0,
-             minimum=1,
-             maximum=5,
-             step=0.5,
-             label="Guidance Scale(s)",
-             info="No guidance applied at s=1, hence faster inference.",
-         ),
-         gr.Slider(value=20, minimum=5, maximum=1000, step=10, label="# of Inference Steps"),
-         gr.Checkbox(
-             value=False, label="Upscale Output", info="Upscale output by 4x (2048x2048) using an off-the-shelf model."
-         ),
-     ],
-     outputs=gr.Image(type="pil", label="Generated Garment", height=448),
-     title=title,
-     description=description,
-     article=article,
-     examples=examples,
-     examples_per_page=4,
-     submit_btn="Generate",
- )
-
- demo.launch()
  import os
+ import time
  from pathlib import Path

  import torch
  from torchvision.io import read_image
  import torchvision.transforms.v2 as transforms
  from torchvision.utils import make_grid

+ import gradio as gr
+ from diffusers import AutoencoderKL, EulerDiscreteScheduler
+ from transformers import SiglipImageProcessor, SiglipVisionModel
+ from huggingface_hub import hf_hub_download
+ import spaces

+ from esrgan_model import UpscalerESRGAN
+ from model import create_model

+ device = "cuda"

+ # Custom transform to pad images to square
  class PadToSquare:
      def __call__(self, img):
+         _, h, w = img.shape
          max_side = max(h, w)
          pad_h = (max_side - h) // 2
          pad_w = (max_side - w) // 2
          padding = (pad_w, pad_h, max_side - w - pad_w, max_side - h - pad_h)
          return transforms.functional.pad(img, padding, padding_mode="edge")

+ # Timer decorator
+ def timer_func(func):
+     def wrapper(*args, **kwargs):
+         t0 = time.time()
+         result = func(*args, **kwargs)
+         print(f"{func.__name__} took {time.time() - t0:.2f} seconds")
+         return result
+     return wrapper

+ @timer_func
+ def load_model(model_class_name, model_filename, repo_id: str = "rizavelioglu/tryoffdiff"):
+     path_model = hf_hub_download(repo_id=repo_id, filename=model_filename, force_download=False)
+     state_dict = torch.load(path_model, weights_only=True, map_location=device)
+     state_dict = {k.replace('_orig_mod.', ''): v for k, v in state_dict.items()}
+     model = create_model(model_class_name).to(device)
+     # model = torch.compile(model)
+     model.load_state_dict(state_dict, strict=True)
+     return model.eval()
+
+ @spaces.GPU(duration=10)
+ @torch.no_grad()
+ @timer_func
+ def generate_multi_image(input_image, garment_types, seed=42, guidance_scale=2.0, num_inference_steps=50, is_upscale=False):
+     label_map = {"Upper-Body": 0, "Lower-Body": 1, "Dress": 2}
+     valid_single = ["Upper-Body", "Lower-Body", "Dress"]
+     valid_tuple = ["Upper-Body", "Lower-Body"]
+
+     if not garment_types:
+         raise gr.Error("Please select at least one garment type.")
+     if len(garment_types) == 1 and garment_types[0] in valid_single:
+         selected, label_indices = garment_types, [label_map[garment_types[0]]]
+     elif sorted(garment_types) == sorted(valid_tuple):
+         selected, label_indices = valid_tuple, [label_map[t] for t in valid_tuple]
+     else:
+         raise gr.Error("Invalid selection. Choose one garment type or Upper-Body and Lower-Body together.")
+
+     batch_size = len(selected)
+     scheduler.set_timesteps(num_inference_steps)
+     generator = torch.Generator(device=device).manual_seed(seed)
+     x = torch.randn(batch_size, 4, 64, 64, generator=generator, device=device)
+
+     # Process inputs
+     cond_image = img_enc_transform(read_image(input_image))
+     inputs = {k: v.to(device) for k, v in img_processor(images=cond_image, return_tensors="pt").items()}
+     cond_emb = img_enc(**inputs).last_hidden_state.to(device)
+     cond_emb = cond_emb.expand(batch_size, *cond_emb.shape[1:])
+     uncond_emb = torch.zeros_like(cond_emb) if guidance_scale > 1 else None
+     label = torch.tensor(label_indices, device=device, dtype=torch.int64)
+     model = models["multi"]
+
+     with torch.autocast(device):
+         for t in scheduler.timesteps:
+             t = t.to(device)  # Ensure t is on the correct device
+             if guidance_scale > 1:
+                 noise_pred = model(torch.cat([x] * 2), t, torch.cat([uncond_emb, cond_emb]), torch.cat([label, label])).chunk(2)
+                 noise_pred = noise_pred[0] + guidance_scale * (noise_pred[1] - noise_pred[0])  # Classifier-free guidance
+             else:
+                 noise_pred = model(x, t, cond_emb, label)  # Standard prediction

+             # Scheduler step
+             scheduler_output = scheduler.step(noise_pred, t, x)
+             x = scheduler_output.prev_sample

+     # Decode predictions from latent space
+     decoded = vae.decode(1 / vae.config.scaling_factor * scheduler_output.pred_original_sample).sample
+     images = (decoded / 2 + 0.5).cpu()
+     grid = make_grid(images, nrow=len(images), normalize=True, scale_each=True)
+     output_image = transforms.ToPILImage()(grid)
+     return upscaler(output_image) if is_upscale else output_image  # Optionally upscale the output image

  @spaces.GPU(duration=10)
  @torch.no_grad()
+ @timer_func
+ def generate_upper_image(input_image, seed=42, guidance_scale=2.0, num_inference_steps=50, is_upscale=False):
+     model = models["upper"]
      scheduler.set_timesteps(num_inference_steps)
+     scheduler.timesteps = scheduler.timesteps.to(device)
      generator = torch.Generator(device=device).manual_seed(seed)
      x = torch.randn(1, 4, 64, 64, generator=generator, device=device)

      # Process input image
      cond_image = img_enc_transform(read_image(input_image))
+     inputs = {k: v.to(device) for k, v in img_processor(images=cond_image, return_tensors="pt").items()}
      cond_emb = img_enc(**inputs).last_hidden_state.to(device)
+     uncond_emb = torch.zeros_like(cond_emb) if guidance_scale > 1 else None
+
+     with torch.autocast(device):
+         for t in scheduler.timesteps:
+             t = t.to(device)  # Ensure t is on the correct device
+             if guidance_scale > 1:  # Classifier-free guidance
+                 noise_pred = model(torch.cat([x] * 2), t, torch.cat([uncond_emb, cond_emb])).chunk(2)
+                 noise_pred = noise_pred[0] + guidance_scale * (noise_pred[1] - noise_pred[0])
+             else:  # Standard prediction
+                 noise_pred = model(x, t, cond_emb)
+
+             # Scheduler step
+             scheduler_output = scheduler.step(noise_pred, t, x)
+             x = scheduler_output.prev_sample
+
+     # Decode predictions from latent space
+     decoded = vae.decode(1 / vae.config.scaling_factor * scheduler_output.pred_original_sample).sample
+     images = (decoded / 2 + 0.5).cpu()
+     grid = make_grid(images, nrow=len(images), normalize=True, scale_each=True)
+     output_image = transforms.ToPILImage()(grid)
+     return upscaler(output_image) if is_upscale else output_image  # Optionally upscale the output image
+
+ @spaces.GPU(duration=10)
+ @torch.no_grad()
+ @timer_func
+ def generate_lower_image(input_image, seed=42, guidance_scale=2.0, num_inference_steps=50, is_upscale=False):
+     model = models["lower"]
+     scheduler.set_timesteps(num_inference_steps)
+     scheduler.timesteps = scheduler.timesteps.to(device)
+     generator = torch.Generator(device=device).manual_seed(seed)
+     x = torch.randn(1, 4, 64, 64, generator=generator, device=device)

+     # Process input image
+     cond_image = img_enc_transform(read_image(input_image))
+     inputs = {k: v.to(device) for k, v in img_processor(images=cond_image, return_tensors="pt").items()}
+     cond_emb = img_enc(**inputs).last_hidden_state.to(device)
      uncond_emb = torch.zeros_like(cond_emb) if guidance_scale > 1 else None

      with torch.autocast(device):
          for t in scheduler.timesteps:
+             t = t.to(device)  # Ensure t is on the correct device
+             if guidance_scale > 1:  # Classifier-free guidance
+                 noise_pred = model(torch.cat([x] * 2), t, torch.cat([uncond_emb, cond_emb])).chunk(2)
                  noise_pred = noise_pred[0] + guidance_scale * (noise_pred[1] - noise_pred[0])
+             else:  # Standard prediction
+                 noise_pred = model(x, t, cond_emb)

              # Scheduler step
              scheduler_output = scheduler.step(noise_pred, t, x)
              x = scheduler_output.prev_sample

      # Decode predictions from latent space
+     decoded = vae.decode(1 / vae.config.scaling_factor * scheduler_output.pred_original_sample).sample
      images = (decoded / 2 + 0.5).cpu()
+     grid = make_grid(images, nrow=len(images), normalize=True, scale_each=True)
+     output_image = transforms.ToPILImage()(grid)
+     return upscaler(output_image) if is_upscale else output_image  # Optionally upscale the output image
+
+ @spaces.GPU(duration=10)
+ @torch.no_grad()
+ @timer_func
+ def generate_dress_image(input_image, seed=42, guidance_scale=2.0, num_inference_steps=50, is_upscale=False):
+     model = models["dress"]
+     scheduler.set_timesteps(num_inference_steps)
+     scheduler.timesteps = scheduler.timesteps.to(device)
+     generator = torch.Generator(device=device).manual_seed(seed)
+     x = torch.randn(1, 4, 64, 64, generator=generator, device=device)
+
+     # Process input image
+     cond_image = img_enc_transform(read_image(input_image))
+     inputs = {k: v.to(device) for k, v in img_processor(images=cond_image, return_tensors="pt").items()}
+     cond_emb = img_enc(**inputs).last_hidden_state.to(device)
+     uncond_emb = torch.zeros_like(cond_emb) if guidance_scale > 1 else None
+
+     with torch.autocast(device):
+         for t in scheduler.timesteps:
+             t = t.to(device)  # Ensure t is on the correct device
+             if guidance_scale > 1:  # Classifier-free guidance
+                 noise_pred = model(torch.cat([x] * 2), t, torch.cat([uncond_emb, cond_emb])).chunk(2)
+                 noise_pred = noise_pred[0] + guidance_scale * (noise_pred[1] - noise_pred[0])
+             else:  # Standard prediction
+                 noise_pred = model(x, t, cond_emb)
+
+             # Scheduler step
+             scheduler_output = scheduler.step(noise_pred, t, x)
+             x = scheduler_output.prev_sample

+     # Decode predictions from latent space
+     decoded = vae.decode(1 / vae.config.scaling_factor * scheduler_output.pred_original_sample).sample
+     images = (decoded / 2 + 0.5).cpu()
      grid = make_grid(images, nrow=len(images), normalize=True, scale_each=True)
      output_image = transforms.ToPILImage()(grid)
+     return upscaler(output_image) if is_upscale else output_image  # Optionally upscale the output image
+
+ def create_multi_tab():
+     description = r"""
+ <table class="description-table">
+ <tr>
+ <td width="50%">
+ In total, 4 models are available for generating garments (one in each tab):<br>
+ - <b>Multi-Garment</b>: Generate multiple garments (e.g., upper-body and lower-body) sequentially.<br>
+ - <b>Upper-Body</b>: Generate upper-body garments (e.g., tops, jackets, etc.).<br>
+ - <b>Lower-Body</b>: Generate lower-body garments (e.g., pants, skirts, etc.).<br>
+ - <b>Dress</b>: Generate dresses.<br>
+ </td>
+ <td width="50%">
+ <b>How to use:</b><br>
+ 1. Upload a reference image,<br>
+ 2. Adjust the parameters as needed,<br>
+ 3. Click "Generate" to create the garment(s).<br>
+ &#128161; Individual models perform slightly better than the multi-garment model, but the latter is more versatile.
+ </td>
+ </tr>
+ </table>
+ """
+     examples = [
+         ["examples/048851_0.jpg", ["Upper-Body", "Lower-Body"], 42, 2.0, 20, False],
+         ["examples/048851_0.jpg", ["Upper-Body"], 42, 2.0, 20, False],
+         ["examples/048588_0.jpg", ["Upper-Body", "Lower-Body"], 42, 2.0, 20, False],
+         ["examples/048588_0.jpg", ["Upper-Body"], 42, 2.0, 20, False],
+         ["examples/048643_0.jpg", ["Upper-Body", "Lower-Body"], 42, 2.0, 20, False],
+         ["examples/048643_0.jpg", ["Lower-Body"], 42, 2.0, 20, False],
+         ["examples/048737_0.jpg", ["Dress"], 42, 2.0, 20, False],
+         ["examples/048737_0.jpg", ["Upper-Body", "Lower-Body"], 42, 2.0, 20, False],
+         ["examples/048690_0.jpg", ["Upper-Body", "Lower-Body"], 42, 2.0, 20, False],
+         ["examples/048690_0.jpg", ["Lower-Body"], 42, 2.0, 20, False],
+         ["examples/048691_0.jpg", ["Upper-Body", "Lower-Body"], 42, 2.0, 20, False],
+         ["examples/048691_0.jpg", ["Upper-Body"], 42, 2.0, 20, False],
+         ["examples/048732_0.jpg", ["Upper-Body", "Lower-Body"], 42, 2.0, 20, False],
+         ["examples/048754_0.jpg", ["Upper-Body", "Lower-Body"], 42, 2.0, 20, False],
+         ["examples/048799_0.jpg", ["Upper-Body", "Lower-Body"], 42, 2.0, 20, False],
+         ["examples/048811_0.jpg", ["Upper-Body", "Lower-Body"], 42, 2.0, 20, False],
+         ["examples/048821_0.jpg", ["Upper-Body", "Lower-Body"], 42, 2.0, 20, False],
+         ["examples/048821_0.jpg", ["Upper-Body"], 42, 2.0, 20, False],
+     ]
+
+     with gr.Blocks() as tab:
+         gr.Markdown(title)
+         gr.Markdown(description)
+         with gr.Row():
+             with gr.Column():
+                 input_image = gr.Image(type="filepath", label="Reference Image", height=384, width=384)
+             with gr.Column(min_width=250):
+                 garment_type = gr.CheckboxGroup(["Upper-Body", "Lower-Body", "Dress"], label="Select Garment Type", value=["Upper-Body", "Lower-Body"])
+                 seed = gr.Slider(value=42, minimum=0, maximum=1e6, step=1, label="Seed")
+                 guidance_scale = gr.Slider(value=2.0, minimum=1, maximum=5, step=0.5, label="Guidance Scale(s)", info="No guidance at s=1.")
+                 inference_steps = gr.Slider(value=20, minimum=5, maximum=1000, step=10, label="# of Inference Steps")
+                 upscale = gr.Checkbox(value=False, label="Upscale Output", info="Upscale output by 4x (2048x2048) using an off-the-shelf model.")
+                 submit_btn = gr.Button("Generate")
+             with gr.Column():
+                 output_image = gr.Image(type="pil", label="Generated Garment", height=384, width=384)
+         gr.Examples(examples=examples, inputs=[input_image, garment_type, seed, guidance_scale, inference_steps, upscale], outputs=output_image, fn=generate_multi_image, cache_examples=False, examples_per_page=2)
+         gr.Markdown(article)
+         submit_btn.click(
+             fn=generate_multi_image,
+             inputs=[input_image, garment_type, seed, guidance_scale, inference_steps, upscale],
+             outputs=output_image
+         )
+     return tab

+ def create_upper_tab():
+     examples = [[f"examples/{img_filename}", 42, 2.0, 20, False] for img_filename in os.listdir("examples/") if img_filename.endswith("_0.jpg")]
+     examples += [
+         ["examples/00084_00.jpg", 42, 2.0, 20, False],
+         ["examples/00254_00.jpg", 42, 2.0, 20, False],
+         ["examples/00397_00.jpg", 42, 2.0, 20, False],
+         ["examples/01320_00.jpg", 42, 2.0, 20, False],
+         ["examples/02390_00.jpg", 42, 2.0, 20, False],
+         ["examples/14227_00.jpg", 42, 2.0, 20, False],
+     ]
+     with gr.Blocks() as tab:
+         gr.Markdown(title)
+         with gr.Row():
+             with gr.Column():
+                 input_image = gr.Image(type="filepath", label="Reference Image", height=384, width=384)
+             with gr.Column(min_width=250):
+                 seed = gr.Slider(value=42, minimum=0, maximum=1e6, step=1, label="Seed")
+                 guidance_scale = gr.Slider(value=2.0, minimum=1, maximum=5, step=0.5, label="Guidance Scale(s)", info="No guidance at s=1.")
+                 inference_steps = gr.Slider(value=20, minimum=5, maximum=1000, step=10, label="# of Inference Steps")
+                 upscale = gr.Checkbox(value=False, label="Upscale Output", info="Upscale output by 4x (2048x2048) using an off-the-shelf model.")
+                 submit_btn = gr.Button("Generate")
+             with gr.Column():
+                 output_image = gr.Image(type="pil", label="Generated Garment", height=384, width=384)
+         gr.Examples(examples=examples, inputs=[input_image, seed, guidance_scale, inference_steps, upscale], outputs=output_image, fn=generate_upper_image, cache_examples=False, examples_per_page=2)
+         gr.Markdown(article)
+         submit_btn.click(
+             fn=generate_upper_image,
+             inputs=[input_image, seed, guidance_scale, inference_steps, upscale],
+             outputs=output_image
+         )
+     return tab

+ def create_lower_tab():
+     examples = [[f"examples/{img_filename}", 42, 2.0, 20, False] for img_filename in os.listdir("examples/") if img_filename.endswith("_0.jpg")]
+     with gr.Blocks() as tab:
+         gr.Markdown(title)
+         with gr.Row():
+             with gr.Column():
+                 input_image = gr.Image(type="filepath", label="Reference Image", height=384, width=384)
+             with gr.Column(min_width=250):
+                 seed = gr.Slider(value=42, minimum=0, maximum=1e6, step=1, label="Seed")
+                 guidance_scale = gr.Slider(value=2.0, minimum=1, maximum=5, step=0.5, label="Guidance Scale(s)", info="No guidance at s=1.")
+                 inference_steps = gr.Slider(value=20, minimum=5, maximum=1000, step=10, label="# of Inference Steps")
+                 upscale = gr.Checkbox(value=False, label="Upscale Output", info="Upscale output by 4x (2048x2048) using an off-the-shelf model.")
+                 submit_btn = gr.Button("Generate")
+             with gr.Column():
+                 output_image = gr.Image(type="pil", label="Generated Garment", height=384, width=384)
+         gr.Examples(examples=examples, inputs=[input_image, seed, guidance_scale, inference_steps, upscale], outputs=output_image, fn=generate_lower_image, cache_examples=False, examples_per_page=2)
+         gr.Markdown(article)
+         submit_btn.click(
+             fn=generate_lower_image,
+             inputs=[input_image, seed, guidance_scale, inference_steps, upscale],
+             outputs=output_image
+         )
+     return tab

+ def create_dress_tab():
+     examples = [
+         ["examples/053480_0.jpg", 42, 2.0, 20, False],
+         ["examples/048737_0.jpg", 42, 2.0, 20, False],
+         ["examples/048811_0.jpg", 42, 2.0, 20, False],
+         ["examples/053733_0.jpg", 42, 2.0, 20, False],
+         ["examples/052606_0.jpg", 42, 2.0, 20, False],
+         ["examples/053682_0.jpg", 42, 2.0, 20, False],
+         ["examples/052036_0.jpg", 42, 2.0, 20, False],
+         ["examples/052644_0.jpg", 42, 2.0, 20, False],
+     ]
+     with gr.Blocks() as tab:
+         gr.Markdown(title)
+         with gr.Row():
+             with gr.Column():
+                 input_image = gr.Image(type="filepath", label="Reference Image", height=384, width=384)
+             with gr.Column(min_width=250):
+                 seed = gr.Slider(value=42, minimum=0, maximum=1e6, step=1, label="Seed")
+                 guidance_scale = gr.Slider(value=2.0, minimum=1, maximum=5, step=0.5, label="Guidance Scale(s)", info="No guidance at s=1.")
+                 inference_steps = gr.Slider(value=20, minimum=5, maximum=1000, step=10, label="# of Inference Steps")
+                 upscale = gr.Checkbox(value=False, label="Upscale Output", info="Upscale output by 4x (2048x2048) using an off-the-shelf model.")
+                 submit_btn = gr.Button("Generate")
+             with gr.Column():
+                 output_image = gr.Image(type="pil", label="Generated Garment", height=384, width=384)
+         gr.Examples(examples=examples, inputs=[input_image, seed, guidance_scale, inference_steps, upscale], outputs=output_image, fn=generate_dress_image, cache_examples=False, examples_per_page=2)
+         gr.Markdown(article)
+         submit_btn.click(
+             fn=generate_dress_image,
+             inputs=[input_image, seed, guidance_scale, inference_steps, upscale],
+             outputs=output_image
+         )
+     return tab

+ # UI elements
+ title = f"""
+ <div class='center-header' style="flex-direction: row; gap: 1.5em;">
+ <h1 style="font-size:2.2em; margin-bottom:0.1em;">Virtual Try-Off Generator</h1>
+ <a href='https://rizavelioglu.github.io/tryoffdiff' style="align-self:center;">
+ <button style="background-color:#1976d2; color:white; font-weight:bold; border:none; border-radius:4px; padding:4px 10px; font-size:1.1em; cursor:pointer;">
+ &#128279; Project page
+ </button>
+ </a>
+ </div>
  """
  article = r"""
+ **Citation**<br>If you use this work, please give a star and a citation:
  ```
  @article{velioglu2024tryoffdiff,
    title = {TryOffDiff: Virtual-Try-Off via High-Fidelity Garment Reconstruction using Diffusion Models},
    year = {2024},
    note = {\url{https://doi.org/nt3n}}
  }
+   @article{velioglu2025enhancing,
+   title = {Enhancing Person-to-Person Virtual Try-On with Multi-Garment Virtual Try-Off},
+   author = {Velioglu, Riza and Bevandic, Petra and Chan, Robin and Hammer, Barbara},
+   journal = {arXiv},
+   year = {2025},
+   note = {\url{https://doi.org/pn67}}
+   }
  ```
  """
+ # Custom CSS for proper styling
+ custom_css = """
+ .center-header {
+     display: flex;
+     align-items: center;
+     justify-content: center;
+     margin: 0 0 20px 0;
+ }
+ .center-header h1 {
+     margin: 0;
+     text-align: center;
+ }
+ .description-table {
+     width: 100%;
+     border-collapse: collapse;
+ }
+ .description-table td {
+     padding: 10px;
+     vertical-align: top;
+ }
+ """
+
+ if __name__ == "__main__":
+     # Image Encoder and transforms
+     img_enc_transform = transforms.Compose(
+         [
+             PadToSquare(),  # Custom transform to pad the image to a square
+             transforms.Resize((512, 512)),
+             transforms.ToDtype(torch.float32, scale=True),
+             transforms.Normalize(mean=[0.5], std=[0.5]),
+         ]
+     )
+     ckpt = "google/siglip-base-patch16-512"
+     img_processor = SiglipImageProcessor.from_pretrained(ckpt, do_resize=False, do_rescale=False, do_normalize=False)
+     img_enc = SiglipVisionModel.from_pretrained(ckpt).eval().to(device)
+
+     # Initialize VAE (only Decoder will be used) & Noise Scheduler
+     vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").eval().to(device)
+     scheduler = EulerDiscreteScheduler.from_pretrained(
+         hf_hub_download(repo_id="rizavelioglu/tryoffdiff", filename="scheduler/scheduler_config_v2.json", force_download=False)
+     )
+     scheduler.is_scale_input_called = True  # suppress warning
+
+     # Upscaler model
+     upscaler = UpscalerESRGAN(
+         model_path=Path(hf_hub_download(repo_id="philz1337x/upscaler", filename="4x-UltraSharp.pth")),
+         device=torch.device(device),
+         dtype=torch.float32,
+     )
+
+     # Model configurations and loading
+     models = {}
+     model_paths = {
+         "upper": {"class_name": "TryOffDiffv2Single", "path": "tryoffdiffv2_upper.pth"},  # internal code: model_20250213_134430
+         "lower": {"class_name": "TryOffDiffv2Single", "path": "tryoffdiffv2_lower.pth"},  # internal code: model_20250213_134130
+         "dress": {"class_name": "TryOffDiffv2Single", "path": "tryoffdiffv2_dress.pth"},  # internal code: model_20250213_133554
+         "multi": {"class_name": "TryOffDiffv2", "path": "tryoffdiffv2_multi.pth"},  # internal code: model_20250310_155608
+     }
+     for name, cfg in model_paths.items():
+         models[name] = load_model(cfg["class_name"], cfg["path"])
+     torch.cuda.empty_cache()
+
+     # Create tabbed interface
+     demo = gr.TabbedInterface(
+         [create_multi_tab(), create_upper_tab(), create_lower_tab(), create_dress_tab()],
+         ["Multi-Garment", "Upper-Body", "Lower-Body", "Dress"],
+         css=custom_css,
+     )
+
+     demo.launch(ssr_mode=False)
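
For context on the `guidance_scale` slider used by every `generate_*` function above: the denoising loop implements standard classifier-free guidance, mixing the prediction conditioned on the SigLIP image embedding $c$ with the prediction for a zeroed embedding,

$$\hat{\epsilon} = \epsilon_\theta(x_t, t, \mathbf{0}) + s \cdot \left( \epsilon_\theta(x_t, t, c) - \epsilon_\theta(x_t, t, \mathbf{0}) \right)$$

where $s$ is the guidance scale. At $s = 1$ the unconditional branch is skipped entirely, so each step costs a single UNet forward pass.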
esrgan_model.py CHANGED
@@ -15,7 +15,6 @@ import numpy.typing as npt
  import torch
  import torch.nn as nn
  from PIL import Image
- from huggingface_hub import hf_hub_download


  def conv_block(in_nc: int, out_nc: int) -> nn.Sequential:
examples/052036_0.jpg ADDED
examples/052606_0.jpg ADDED
examples/053480_0.jpg ADDED
examples/053682_0.jpg ADDED
model.py ADDED
@@ -0,0 +1,90 @@
+ from enum import Enum, unique
+ from typing import Any
+
+ import torch
+ import torchvision.transforms.v2 as transforms
+ from diffusers import AutoencoderKL, UNet2DConditionModel, UNet2DModel
+ from torch import Tensor, nn
+ from transformers import (
+     AutoImageProcessor,
+     AutoModel,
+     AutoProcessor,
+     CLIPImageProcessor,
+     CLIPVisionModel,
+     SiglipImageProcessor,
+     SiglipVisionModel,
+ )
+
+
+ class TryOffDiff(nn.Module):
+     def __init__(self):
+         super().__init__()
+         self.unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet")
+         self.transformer = torch.nn.TransformerEncoderLayer(d_model=768, nhead=8, batch_first=True)
+         self.proj = nn.Linear(1024, 77)
+         self.norm = nn.LayerNorm(768)
+
+     def forward(self, noisy_latents, t, cond_emb):
+         cond_emb = self.transformer(cond_emb)
+         cond_emb = self.proj(cond_emb.transpose(1, 2))
+         cond_emb = self.norm(cond_emb.transpose(1, 2))
+         return self.unet(noisy_latents, t, encoder_hidden_states=cond_emb).sample
+
+ class TryOffDiffv2(nn.Module):
+     def __init__(self):
+         super().__init__()
+         self.unet = UNet2DConditionModel(
+             sample_size=64,
+             in_channels=4,
+             out_channels=4,
+             layers_per_block=2,
+             block_out_channels=(320, 640, 1280, 1280),
+             down_block_types=(
+                 "CrossAttnDownBlock2D",
+                 "CrossAttnDownBlock2D",
+                 "CrossAttnDownBlock2D",
+                 "DownBlock2D",
+             ),
+             up_block_types=(
+                 "UpBlock2D",
+                 "CrossAttnUpBlock2D",
+                 "CrossAttnUpBlock2D",
+                 "CrossAttnUpBlock2D",
+             ),
+             cross_attention_dim=768,
+             class_embed_type=None,
+             num_class_embeds=3,
+         )
+         # Load the pretrained weights into the custom model, skipping incompatible keys
+         pretrained_state_dict = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet").state_dict()
+         self.unet.load_state_dict(pretrained_state_dict, strict=False)
+
+         self.proj = nn.Linear(1024, 77)
+         self.norm = nn.LayerNorm(768)
+
+     def forward(self, noisy_latents, t, cond_emb, class_labels):
+         cond_emb = self.proj(cond_emb.transpose(1, 2))
+         cond_emb = self.norm(cond_emb.transpose(1, 2))
+         return self.unet(noisy_latents, t, encoder_hidden_states=cond_emb, class_labels=class_labels).sample
+
+ class TryOffDiffv2Single(nn.Module):
+     def __init__(self):
+         super().__init__()
+         self.unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet")
+         self.proj = nn.Linear(1024, 77)
+         self.norm = nn.LayerNorm(768)
+
+     def forward(self, noisy_latents, t, cond_emb):
+         cond_emb = self.proj(cond_emb.transpose(1, 2))
+         cond_emb = self.norm(cond_emb.transpose(1, 2))
+         return self.unet(noisy_latents, t, encoder_hidden_states=cond_emb).sample
+
+ @unique
+ class ModelName(Enum):
+     TryOffDiff = TryOffDiff
+     TryOffDiffv2 = TryOffDiffv2
+     TryOffDiffv2Single = TryOffDiffv2Single
+
+ def create_model(model_name: str, **kwargs: Any) -> Any:
+     model_class = ModelName[model_name].value
+     return model_class(**kwargs)
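
A brief usage sketch of the factory defined above, with dummy tensors shaped the way app.py drives the class-conditional `TryOffDiffv2`; the shapes and the label mapping are taken from this commit, and instantiating the model downloads the Stable Diffusion v1.4 UNet weights:

```python
# Sketch only: dummy inputs shaped like the ones app.py produces.
import torch
from model import create_model

model = create_model("TryOffDiffv2").eval()  # class-conditional variant (num_class_embeds=3)

noisy_latents = torch.randn(2, 4, 64, 64)  # SD v1 latent grid (sample_size=64, 4 channels)
t = torch.tensor([999, 999])               # one diffusion timestep per batch element
cond_emb = torch.randn(2, 1024, 768)       # SigLIP base-patch16-512 last_hidden_state: 1024 tokens x 768 dims
labels = torch.tensor([0, 1])              # garment types from app.py: 0 Upper-Body, 1 Lower-Body, 2 Dress

with torch.no_grad():
    noise_pred = model(noisy_latents, t, cond_emb, labels)  # -> (2, 4, 64, 64)
```

The `TryOffDiffv2Single` variant takes the same arguments minus the labels, matching its three-argument `forward`.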
requirements.txt CHANGED
@@ -1,8 +1,6 @@
- torch>=2.4.0
+ torch>=2.5.1
  torchvision>=0.20.1
- diffusers>=0.31.0
- transformers>=4.46.3
- gradio>=5.7.0
- spaces>=0.30.4
- huggingface-hub>=0.26.2
- accelerate>=1.1.1
+ diffusers>=0.33.1
+ transformers>=4.49.0
+ huggingface-hub>=0.30.2
+ accelerate>=1.2.1