fix app torch grad
Browse files
app.py
CHANGED
@@ -83,12 +83,11 @@ def text_to_image_generation(input_text, guidance_scale=1.75, generation_timeste
|
|
83 |
config=config,
|
84 |
)
|
85 |
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
images = images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8)
|
92 |
|
93 |
return images[0]
|
94 |
|
@@ -158,12 +157,12 @@ def text_guided_inpainting(input_text, inpainting_image, inpainting_mask, guidan
|
|
158 |
config=config,
|
159 |
)
|
160 |
|
161 |
-
|
162 |
-
|
163 |
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
|
168 |
return images[0]
|
169 |
|
@@ -283,11 +282,12 @@ def text_guided_extrapolation(input_img, input_text, left_ext, right_ext, guidan
|
|
283 |
|
284 |
_, h, w = gen_token_ids.shape
|
285 |
gen_token_ids = gen_token_ids.reshape(config.training.batch_size, -1)
|
286 |
-
images = vq_model.decode_code(gen_token_ids, shape=(h, w))
|
287 |
|
288 |
-
|
289 |
-
|
290 |
-
|
|
|
|
|
291 |
|
292 |
return images[0]
|
293 |
|
|
|
83 |
config=config,
|
84 |
)
|
85 |
|
86 |
+
gen_token_ids = torch.clamp(gen_token_ids, max=config.model.showo.codebook_size - 1, min=0)
|
87 |
+
images = vq_model.decode_code(gen_token_ids)
|
88 |
+
images = torch.clamp((images + 1.0) / 2.0, min=0.0, max=1.0)
|
89 |
+
images *= 255.0
|
90 |
+
images = images.permute(0, 2, 3, 1).cpu().detach().numpy().astype(np.uint8)
|
|
|
91 |
|
92 |
return images[0]
|
93 |
|
|
|
157 |
config=config,
|
158 |
)
|
159 |
|
160 |
+
gen_token_ids = torch.clamp(gen_token_ids, max=config.model.showo.codebook_size - 1, min=0)
|
161 |
+
images = vq_model.decode_code(gen_token_ids)
|
162 |
|
163 |
+
images = torch.clamp((images + 1.0) / 2.0, min=0.0, max=1.0)
|
164 |
+
images *= 255.0
|
165 |
+
images = images.permute(0, 2, 3, 1).cpu().detach().numpy().astype(np.uint8)
|
166 |
|
167 |
return images[0]
|
168 |
|
|
|
282 |
|
283 |
_, h, w = gen_token_ids.shape
|
284 |
gen_token_ids = gen_token_ids.reshape(config.training.batch_size, -1)
|
|
|
285 |
|
286 |
+
with torch.no_grad():
|
287 |
+
images = vq_model.decode_code(gen_token_ids, shape=(h, w))
|
288 |
+
images = torch.clamp((images + 1.0) / 2.0, min=0.0, max=1.0)
|
289 |
+
images *= 255.0
|
290 |
+
images = images.permute(0, 2, 3, 1).cpu().detach().numpy().astype(np.uint8)
|
291 |
|
292 |
return images[0]
|
293 |
|