rizavelioglu commited on
Commit
bcb4ca1
·
1 Parent(s): 0a0606b

handle input validation outside GPU

Browse files

to show the error message in the UI, otherwise generic error is shown: 'ZeroGPU worker error'

Files changed (1) hide show
  1. app.py +16 -6
app.py CHANGED
@@ -47,10 +47,8 @@ def load_model(model_class_name, model_filename, repo_id: str = "rizavelioglu/tr
47
  model.load_state_dict(state_dict, strict=True)
48
  return model.eval()
49
 
50
- @spaces.GPU(duration=10)
51
- @torch.no_grad()
52
- @timer_func
53
- def generate_multi_image(input_image, garment_types, seed=42, guidance_scale=2.0, num_inference_steps=50, is_upscale=False):
54
  label_map = {"Upper-Body": 0, "Lower-Body": 1, "Dress": 2}
55
  valid_single = ["Upper-Body", "Lower-Body", "Dress"]
56
  valid_tuple = ["Upper-Body", "Lower-Body"]
@@ -63,7 +61,19 @@ def generate_multi_image(input_image, garment_types, seed=42, guidance_scale=2.0
63
  selected, label_indices = valid_tuple, [label_map[t] for t in valid_tuple]
64
  else:
65
  raise gr.Error("Invalid selection. Choose one garment type or Upper-Body and Lower-Body together.")
 
 
 
 
 
 
 
 
66
 
 
 
 
 
67
  batch_size = len(selected)
68
  scheduler.set_timesteps(num_inference_steps)
69
  generator = torch.Generator(device=device).manual_seed(seed)
@@ -263,10 +273,10 @@ def create_multi_tab():
263
  submit_btn = gr.Button("Generate")
264
  with gr.Column():
265
  output_image = gr.Image(type="pil", label="Generated Garment", height=384, width=384)
266
- gr.Examples(examples=examples, inputs=[input_image, garment_type, seed, guidance_scale, inference_steps, upscale], outputs=output_image, fn=generate_multi_image, cache_examples=False, examples_per_page=2)
267
  gr.Markdown(article)
268
  submit_btn.click(
269
- fn=generate_multi_image,
270
  inputs=[input_image, garment_type, seed, guidance_scale, inference_steps, upscale],
271
  outputs=output_image
272
  )
 
47
  model.load_state_dict(state_dict, strict=True)
48
  return model.eval()
49
 
50
+ def validate_garment_selection(garment_types):
51
+ """Validate garment type selection and return selected types and label indices."""
 
 
52
  label_map = {"Upper-Body": 0, "Lower-Body": 1, "Dress": 2}
53
  valid_single = ["Upper-Body", "Lower-Body", "Dress"]
54
  valid_tuple = ["Upper-Body", "Lower-Body"]
 
61
  selected, label_indices = valid_tuple, [label_map[t] for t in valid_tuple]
62
  else:
63
  raise gr.Error("Invalid selection. Choose one garment type or Upper-Body and Lower-Body together.")
64
+
65
+ return selected, label_indices
66
+
67
+ def generate_multi_image_wrapper(input_image, garment_types, seed=42, guidance_scale=2.0, num_inference_steps=50, is_upscale=False):
68
+ """Wrapper function that validates input before calling the GPU function."""
69
+ # Validate selection before entering GPU context
70
+ selected, label_indices = validate_garment_selection(garment_types)
71
+ return generate_multi_image(input_image, selected, label_indices, seed, guidance_scale, num_inference_steps, is_upscale)
72
 
73
+ @spaces.GPU(duration=10)
74
+ @torch.no_grad()
75
+ @timer_func
76
+ def generate_multi_image(input_image, selected, label_indices, seed=42, guidance_scale=2.0, num_inference_steps=50, is_upscale=False):
77
  batch_size = len(selected)
78
  scheduler.set_timesteps(num_inference_steps)
79
  generator = torch.Generator(device=device).manual_seed(seed)
 
273
  submit_btn = gr.Button("Generate")
274
  with gr.Column():
275
  output_image = gr.Image(type="pil", label="Generated Garment", height=384, width=384)
276
+ gr.Examples(examples=examples, inputs=[input_image, garment_type, seed, guidance_scale, inference_steps, upscale], outputs=output_image, fn=generate_multi_image_wrapper, cache_examples=False, examples_per_page=2)
277
  gr.Markdown(article)
278
  submit_btn.click(
279
+ fn=generate_multi_image_wrapper,
280
  inputs=[input_image, garment_type, seed, guidance_scale, inference_steps, upscale],
281
  outputs=output_image
282
  )