FrozenBurning commited on
Commit
6cf1b17
·
1 Parent(s): eb61402

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -17
app.py CHANGED
@@ -74,9 +74,21 @@ config.model.pop("latent_std")
74
  model_primx = load_from_config(config.model)
75
  # load rembg
76
  rembg_session = rembg.new_session()
 
 
 
 
 
 
 
 
 
 
 
 
77
 
78
  # process function
79
- def process(input_image, input_num_steps, input_seed=42, input_cfg=6.0):
80
  # seed
81
  torch.manual_seed(input_seed)
82
 
@@ -91,16 +103,8 @@ def process(input_image, input_num_steps, input_seed=42, input_cfg=6.0):
91
  fwd_fn = model.forward_with_cfg
92
 
93
  # text-conditioned
94
- if input_image is None:
95
  raise NotImplementedError
96
- # image-conditioned (may also input text, but no text usually works too)
97
- else:
98
- input_image = remove_background(input_image, rembg_session)
99
- input_image = resize_foreground(input_image, 0.85)
100
- raw_image = np.array(input_image)
101
- mask = (raw_image[..., -1][..., None] > 0) * 1
102
- raw_image = raw_image[..., :3] * mask
103
- input_cond = torch.from_numpy(np.array(raw_image)[None, ...]).to(device)
104
 
105
  with torch.no_grad():
106
  latent = torch.randn(1, config.model.num_prims, 1, 4, 4, 4)
@@ -178,8 +182,11 @@ with block:
178
 
179
  with gr.Row(variant='panel'):
180
  with gr.Column(scale=1):
181
- # input image
182
- input_image = gr.Image(label="image", type='pil')
 
 
 
183
  # inference steps
184
  input_num_steps = gr.Radio(choices=[25, 50, 100, 200], label="DDIM steps", value=25)
185
  # random seed
@@ -187,7 +194,7 @@ with block:
187
  # random seed
188
  input_seed = gr.Slider(label="random seed", minimum=0, maximum=10000, step=1, value=42, info="Try different seed if the result is not satisfying as this is a generative model!")
189
  # gen button
190
- button_gen = gr.Button("Generate")
191
  export_glb_btn = gr.Button(value="Export GLB", interactive=False)
192
 
193
  with gr.Column(scale=1):
@@ -231,15 +238,16 @@ with block:
231
  outputs=[output_glb],
232
  )
233
 
234
- button_gen.click(process, inputs=[input_image, input_num_steps, input_seed, input_cfg], outputs=[output_rgb_video, output_prim_video, output_mat_video, export_glb_btn])
 
 
235
 
236
  export_glb_btn.click(export_mesh, inputs=[], outputs=[output_glb, hdr_row])
237
 
238
  gr.Examples(
239
  examples=[
240
- "assets/examples/fruit_elephant.jpg",
241
- "assets/examples/mei_ling_panda.png",
242
- "assets/examples/shuai_panda_notail.png",
243
  ],
244
  inputs=[input_image],
245
  outputs=[output_rgb_video, output_prim_video, output_mat_video, export_glb_btn],
 
74
  model_primx = load_from_config(config.model)
75
  # load rembg
76
  rembg_session = rembg.new_session()
77
+ current_fg_state = None
78
+
79
+ # background removal function
80
+ def background_remove_process(input_image):
81
+ input_image = remove_background(input_image, rembg_session)
82
+ input_image = resize_foreground(input_image, 0.85)
83
+ input_cond_preview_pil = input_image
84
+ raw_image = np.array(input_image)
85
+ mask = (raw_image[..., -1][..., None] > 0) * 1
86
+ raw_image = raw_image[..., :3] * mask
87
+ input_cond = torch.from_numpy(np.array(raw_image)[None, ...]).to(device)
88
+ return gr.update(interactive=True), input_cond, input_cond_preview_pil
89
 
90
  # process function
91
+ def process(input_cond, input_num_steps, input_seed=42, input_cfg=6.0):
92
  # seed
93
  torch.manual_seed(input_seed)
94
 
 
103
  fwd_fn = model.forward_with_cfg
104
 
105
  # text-conditioned
106
+ if input_cond is None:
107
  raise NotImplementedError
 
 
 
 
 
 
 
 
108
 
109
  with torch.no_grad():
110
  latent = torch.randn(1, config.model.num_prims, 1, 4, 4, 4)
 
182
 
183
  with gr.Row(variant='panel'):
184
  with gr.Column(scale=1):
185
+ with gr.Row():
186
+ # input image
187
+ input_image = gr.Image(label="image", type='pil')
188
+ # background removal
189
+ removal_previewer = gr.Image(label="Background Removal Preview", type='pil', interactive=False)
190
  # inference steps
191
  input_num_steps = gr.Radio(choices=[25, 50, 100, 200], label="DDIM steps", value=25)
192
  # random seed
 
194
  # random seed
195
  input_seed = gr.Slider(label="random seed", minimum=0, maximum=10000, step=1, value=42, info="Try different seed if the result is not satisfying as this is a generative model!")
196
  # gen button
197
+ button_gen = gr.Button(value="Generate", interactive=False)
198
  export_glb_btn = gr.Button(value="Export GLB", interactive=False)
199
 
200
  with gr.Column(scale=1):
 
238
  outputs=[output_glb],
239
  )
240
 
241
+ input_image.change(background_remove_process, inputs=[input_image], outputs=[button_gen, current_fg_state, removal_previewer])
242
+
243
+ button_gen.click(process, inputs=[current_fg_state, input_num_steps, input_seed, input_cfg], outputs=[output_rgb_video, output_prim_video, output_mat_video, export_glb_btn])
244
 
245
  export_glb_btn.click(export_mesh, inputs=[], outputs=[output_glb, hdr_row])
246
 
247
  gr.Examples(
248
  examples=[
249
+ os.path.join("assets/examples", f)
250
+ for f in os.listdir("assets/examples")
 
251
  ],
252
  inputs=[input_image],
253
  outputs=[output_rgb_video, output_prim_video, output_mat_video, export_glb_btn],