FrozenBurning
commited on
Commit
·
6cf1b17
1
Parent(s):
eb61402
Update app.py
Browse files
app.py
CHANGED
@@ -74,9 +74,21 @@ config.model.pop("latent_std")
|
|
74 |
model_primx = load_from_config(config.model)
|
75 |
# load rembg
|
76 |
rembg_session = rembg.new_session()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
77 |
|
78 |
# process function
|
79 |
-
def process(
|
80 |
# seed
|
81 |
torch.manual_seed(input_seed)
|
82 |
|
@@ -91,16 +103,8 @@ def process(input_image, input_num_steps, input_seed=42, input_cfg=6.0):
|
|
91 |
fwd_fn = model.forward_with_cfg
|
92 |
|
93 |
# text-conditioned
|
94 |
-
if
|
95 |
raise NotImplementedError
|
96 |
-
# image-conditioned (may also input text, but no text usually works too)
|
97 |
-
else:
|
98 |
-
input_image = remove_background(input_image, rembg_session)
|
99 |
-
input_image = resize_foreground(input_image, 0.85)
|
100 |
-
raw_image = np.array(input_image)
|
101 |
-
mask = (raw_image[..., -1][..., None] > 0) * 1
|
102 |
-
raw_image = raw_image[..., :3] * mask
|
103 |
-
input_cond = torch.from_numpy(np.array(raw_image)[None, ...]).to(device)
|
104 |
|
105 |
with torch.no_grad():
|
106 |
latent = torch.randn(1, config.model.num_prims, 1, 4, 4, 4)
|
@@ -178,8 +182,11 @@ with block:
|
|
178 |
|
179 |
with gr.Row(variant='panel'):
|
180 |
with gr.Column(scale=1):
|
181 |
-
|
182 |
-
|
|
|
|
|
|
|
183 |
# inference steps
|
184 |
input_num_steps = gr.Radio(choices=[25, 50, 100, 200], label="DDIM steps", value=25)
|
185 |
# random seed
|
@@ -187,7 +194,7 @@ with block:
|
|
187 |
# random seed
|
188 |
input_seed = gr.Slider(label="random seed", minimum=0, maximum=10000, step=1, value=42, info="Try different seed if the result is not satisfying as this is a generative model!")
|
189 |
# gen button
|
190 |
-
button_gen = gr.Button("Generate")
|
191 |
export_glb_btn = gr.Button(value="Export GLB", interactive=False)
|
192 |
|
193 |
with gr.Column(scale=1):
|
@@ -231,15 +238,16 @@ with block:
|
|
231 |
outputs=[output_glb],
|
232 |
)
|
233 |
|
234 |
-
|
|
|
|
|
235 |
|
236 |
export_glb_btn.click(export_mesh, inputs=[], outputs=[output_glb, hdr_row])
|
237 |
|
238 |
gr.Examples(
|
239 |
examples=[
|
240 |
-
"assets/examples
|
241 |
-
"assets/examples
|
242 |
-
"assets/examples/shuai_panda_notail.png",
|
243 |
],
|
244 |
inputs=[input_image],
|
245 |
outputs=[output_rgb_video, output_prim_video, output_mat_video, export_glb_btn],
|
|
|
74 |
model_primx = load_from_config(config.model)
|
75 |
# load rembg
|
76 |
rembg_session = rembg.new_session()
|
77 |
+
current_fg_state = None
|
78 |
+
|
79 |
+
# background removal function
|
80 |
+
def background_remove_process(input_image):
|
81 |
+
input_image = remove_background(input_image, rembg_session)
|
82 |
+
input_image = resize_foreground(input_image, 0.85)
|
83 |
+
input_cond_preview_pil = input_image
|
84 |
+
raw_image = np.array(input_image)
|
85 |
+
mask = (raw_image[..., -1][..., None] > 0) * 1
|
86 |
+
raw_image = raw_image[..., :3] * mask
|
87 |
+
input_cond = torch.from_numpy(np.array(raw_image)[None, ...]).to(device)
|
88 |
+
return gr.update(interactive=True), input_cond, input_cond_preview_pil
|
89 |
|
90 |
# process function
|
91 |
+
def process(input_cond, input_num_steps, input_seed=42, input_cfg=6.0):
|
92 |
# seed
|
93 |
torch.manual_seed(input_seed)
|
94 |
|
|
|
103 |
fwd_fn = model.forward_with_cfg
|
104 |
|
105 |
# text-conditioned
|
106 |
+
if input_cond is None:
|
107 |
raise NotImplementedError
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
108 |
|
109 |
with torch.no_grad():
|
110 |
latent = torch.randn(1, config.model.num_prims, 1, 4, 4, 4)
|
|
|
182 |
|
183 |
with gr.Row(variant='panel'):
|
184 |
with gr.Column(scale=1):
|
185 |
+
with gr.Row():
|
186 |
+
# input image
|
187 |
+
input_image = gr.Image(label="image", type='pil')
|
188 |
+
# background removal
|
189 |
+
removal_previewer = gr.Image(label="Background Removal Preview", type='pil', interactive=False)
|
190 |
# inference steps
|
191 |
input_num_steps = gr.Radio(choices=[25, 50, 100, 200], label="DDIM steps", value=25)
|
192 |
# random seed
|
|
|
194 |
# random seed
|
195 |
input_seed = gr.Slider(label="random seed", minimum=0, maximum=10000, step=1, value=42, info="Try different seed if the result is not satisfying as this is a generative model!")
|
196 |
# gen button
|
197 |
+
button_gen = gr.Button(value="Generate", interactive=False)
|
198 |
export_glb_btn = gr.Button(value="Export GLB", interactive=False)
|
199 |
|
200 |
with gr.Column(scale=1):
|
|
|
238 |
outputs=[output_glb],
|
239 |
)
|
240 |
|
241 |
+
input_image.change(background_remove_process, inputs=[input_image], outputs=[button_gen, current_fg_state, removal_previewer])
|
242 |
+
|
243 |
+
button_gen.click(process, inputs=[current_fg_state, input_num_steps, input_seed, input_cfg], outputs=[output_rgb_video, output_prim_video, output_mat_video, export_glb_btn])
|
244 |
|
245 |
export_glb_btn.click(export_mesh, inputs=[], outputs=[output_glb, hdr_row])
|
246 |
|
247 |
gr.Examples(
|
248 |
examples=[
|
249 |
+
os.path.join("assets/examples", f)
|
250 |
+
for f in os.listdir("assets/examples")
|
|
|
251 |
],
|
252 |
inputs=[input_image],
|
253 |
outputs=[output_rgb_video, output_prim_video, output_mat_video, export_glb_btn],
|