Spaces:
Running
on
Zero
Running
on
Zero
rizavelioglu
commited on
Commit
·
bcb4ca1
1
Parent(s):
0a0606b
handle input validation outside GPU
Browse filesto show the error message in the UI, otherwise generic error is shown: 'ZeroGPU worker error'
app.py
CHANGED
@@ -47,10 +47,8 @@ def load_model(model_class_name, model_filename, repo_id: str = "rizavelioglu/tr
|
|
47 |
model.load_state_dict(state_dict, strict=True)
|
48 |
return model.eval()
|
49 |
|
50 |
-
|
51 |
-
|
52 |
-
@timer_func
|
53 |
-
def generate_multi_image(input_image, garment_types, seed=42, guidance_scale=2.0, num_inference_steps=50, is_upscale=False):
|
54 |
label_map = {"Upper-Body": 0, "Lower-Body": 1, "Dress": 2}
|
55 |
valid_single = ["Upper-Body", "Lower-Body", "Dress"]
|
56 |
valid_tuple = ["Upper-Body", "Lower-Body"]
|
@@ -63,7 +61,19 @@ def generate_multi_image(input_image, garment_types, seed=42, guidance_scale=2.0
|
|
63 |
selected, label_indices = valid_tuple, [label_map[t] for t in valid_tuple]
|
64 |
else:
|
65 |
raise gr.Error("Invalid selection. Choose one garment type or Upper-Body and Lower-Body together.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
66 |
|
|
|
|
|
|
|
|
|
67 |
batch_size = len(selected)
|
68 |
scheduler.set_timesteps(num_inference_steps)
|
69 |
generator = torch.Generator(device=device).manual_seed(seed)
|
@@ -263,10 +273,10 @@ def create_multi_tab():
|
|
263 |
submit_btn = gr.Button("Generate")
|
264 |
with gr.Column():
|
265 |
output_image = gr.Image(type="pil", label="Generated Garment", height=384, width=384)
|
266 |
-
gr.Examples(examples=examples, inputs=[input_image, garment_type, seed, guidance_scale, inference_steps, upscale], outputs=output_image, fn=
|
267 |
gr.Markdown(article)
|
268 |
submit_btn.click(
|
269 |
-
fn=
|
270 |
inputs=[input_image, garment_type, seed, guidance_scale, inference_steps, upscale],
|
271 |
outputs=output_image
|
272 |
)
|
|
|
47 |
model.load_state_dict(state_dict, strict=True)
|
48 |
return model.eval()
|
49 |
|
50 |
+
def validate_garment_selection(garment_types):
|
51 |
+
"""Validate garment type selection and return selected types and label indices."""
|
|
|
|
|
52 |
label_map = {"Upper-Body": 0, "Lower-Body": 1, "Dress": 2}
|
53 |
valid_single = ["Upper-Body", "Lower-Body", "Dress"]
|
54 |
valid_tuple = ["Upper-Body", "Lower-Body"]
|
|
|
61 |
selected, label_indices = valid_tuple, [label_map[t] for t in valid_tuple]
|
62 |
else:
|
63 |
raise gr.Error("Invalid selection. Choose one garment type or Upper-Body and Lower-Body together.")
|
64 |
+
|
65 |
+
return selected, label_indices
|
66 |
+
|
67 |
+
def generate_multi_image_wrapper(input_image, garment_types, seed=42, guidance_scale=2.0, num_inference_steps=50, is_upscale=False):
|
68 |
+
"""Wrapper function that validates input before calling the GPU function."""
|
69 |
+
# Validate selection before entering GPU context
|
70 |
+
selected, label_indices = validate_garment_selection(garment_types)
|
71 |
+
return generate_multi_image(input_image, selected, label_indices, seed, guidance_scale, num_inference_steps, is_upscale)
|
72 |
|
73 |
+
@spaces.GPU(duration=10)
|
74 |
+
@torch.no_grad()
|
75 |
+
@timer_func
|
76 |
+
def generate_multi_image(input_image, selected, label_indices, seed=42, guidance_scale=2.0, num_inference_steps=50, is_upscale=False):
|
77 |
batch_size = len(selected)
|
78 |
scheduler.set_timesteps(num_inference_steps)
|
79 |
generator = torch.Generator(device=device).manual_seed(seed)
|
|
|
273 |
submit_btn = gr.Button("Generate")
|
274 |
with gr.Column():
|
275 |
output_image = gr.Image(type="pil", label="Generated Garment", height=384, width=384)
|
276 |
+
gr.Examples(examples=examples, inputs=[input_image, garment_type, seed, guidance_scale, inference_steps, upscale], outputs=output_image, fn=generate_multi_image_wrapper, cache_examples=False, examples_per_page=2)
|
277 |
gr.Markdown(article)
|
278 |
submit_btn.click(
|
279 |
+
fn=generate_multi_image_wrapper,
|
280 |
inputs=[input_image, garment_type, seed, guidance_scale, inference_steps, upscale],
|
281 |
outputs=output_image
|
282 |
)
|