JosephBai commited on
Commit
2fad823
·
1 Parent(s): 3c82dd9

fix app torch grad

Browse files
Files changed (1) hide show
  1. app.py +15 -15
app.py CHANGED
@@ -83,12 +83,11 @@ def text_to_image_generation(input_text, guidance_scale=1.75, generation_timeste
83
  config=config,
84
  )
85
 
86
- gen_token_ids = torch.clamp(gen_token_ids, max=config.model.showo.codebook_size - 1, min=0)
87
- images = vq_model.decode_code(gen_token_ids)
88
-
89
- images = torch.clamp((images + 1.0) / 2.0, min=0.0, max=1.0)
90
- images *= 255.0
91
- images = images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8)
92
 
93
  return images[0]
94
 
@@ -158,12 +157,12 @@ def text_guided_inpainting(input_text, inpainting_image, inpainting_mask, guidan
158
  config=config,
159
  )
160
 
161
- gen_token_ids = torch.clamp(gen_token_ids, max=config.model.showo.codebook_size - 1, min=0)
162
- images = vq_model.decode_code(gen_token_ids)
163
 
164
- images = torch.clamp((images + 1.0) / 2.0, min=0.0, max=1.0)
165
- images *= 255.0
166
- images = images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8)
167
 
168
  return images[0]
169
 
@@ -283,11 +282,12 @@ def text_guided_extrapolation(input_img, input_text, left_ext, right_ext, guidan
283
 
284
  _, h, w = gen_token_ids.shape
285
  gen_token_ids = gen_token_ids.reshape(config.training.batch_size, -1)
286
- images = vq_model.decode_code(gen_token_ids, shape=(h, w))
287
 
288
- images = torch.clamp((images + 1.0) / 2.0, min=0.0, max=1.0)
289
- images *= 255.0
290
- images = images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8)
 
 
291
 
292
  return images[0]
293
 
 
83
  config=config,
84
  )
85
 
86
+ gen_token_ids = torch.clamp(gen_token_ids, max=config.model.showo.codebook_size - 1, min=0)
87
+ images = vq_model.decode_code(gen_token_ids)
88
+ images = torch.clamp((images + 1.0) / 2.0, min=0.0, max=1.0)
89
+ images *= 255.0
90
+ images = images.permute(0, 2, 3, 1).cpu().detach().numpy().astype(np.uint8)
 
91
 
92
  return images[0]
93
 
 
157
  config=config,
158
  )
159
 
160
+ gen_token_ids = torch.clamp(gen_token_ids, max=config.model.showo.codebook_size - 1, min=0)
161
+ images = vq_model.decode_code(gen_token_ids)
162
 
163
+ images = torch.clamp((images + 1.0) / 2.0, min=0.0, max=1.0)
164
+ images *= 255.0
165
+ images = images.permute(0, 2, 3, 1).cpu().detach().numpy().astype(np.uint8)
166
 
167
  return images[0]
168
 
 
282
 
283
  _, h, w = gen_token_ids.shape
284
  gen_token_ids = gen_token_ids.reshape(config.training.batch_size, -1)
 
285
 
286
+ with torch.no_grad():
287
+ images = vq_model.decode_code(gen_token_ids, shape=(h, w))
288
+ images = torch.clamp((images + 1.0) / 2.0, min=0.0, max=1.0)
289
+ images *= 255.0
290
+ images = images.permute(0, 2, 3, 1).cpu().detach().numpy().astype(np.uint8)
291
 
292
  return images[0]
293