EthanZyh commited on
Commit
4402ae1
·
1 Parent(s): acb64bd

fixed silly bugs. can work now.

Browse files
Files changed (2) hide show
  1. general_dit.py +3 -3
  2. text2world_hf.py +1 -1
general_dit.py CHANGED
@@ -390,16 +390,16 @@ class GeneralDIT(nn.Module):
390
  latent_condition_sigma=latent_condition_sigma,
391
  )
392
  # logging affline scale information
393
- affline_scale_log.info = {}
394
 
395
  timesteps_B_D, adaln_lora_B_3D = self.t_embedder(timesteps.flatten())
396
  affline_emb_B_D = timesteps_B_D
397
- affline_scale_log.info["timesteps_B_D"] = timesteps_B_D.detach()
398
 
399
  if scalar_feature is not None:
400
  raise NotImplementedError("Scalar feature is not implemented yet.")
401
 
402
- affline_scale_log.info["affline_emb_B_D"] = affline_emb_B_D.detach()
403
  affline_emb_B_D = self.affline_norm(affline_emb_B_D)
404
 
405
  if self.use_cross_attn_mask:
 
390
  latent_condition_sigma=latent_condition_sigma,
391
  )
392
  # logging affline scale information
393
+ affline_scale_log_info = {}
394
 
395
  timesteps_B_D, adaln_lora_B_3D = self.t_embedder(timesteps.flatten())
396
  affline_emb_B_D = timesteps_B_D
397
+ affline_scale_log_info["timesteps_B_D"] = timesteps_B_D.detach()
398
 
399
  if scalar_feature is not None:
400
  raise NotImplementedError("Scalar feature is not implemented yet.")
401
 
402
+ affline_scale_log_info["affline_emb_B_D"] = affline_emb_B_D.detach()
403
  affline_emb_B_D = self.affline_norm(affline_emb_B_D)
404
 
405
  if self.use_cross_attn_mask:
text2world_hf.py CHANGED
@@ -83,7 +83,7 @@ class DiffusionText2World(PreTrainedModel):
83
  prompts = read_prompts_from_file(cfg.batch_input_path)
84
  else:
85
  # Single prompt case
86
- prompts = [{"prompt": cfg.prompt}]
87
 
88
  os.makedirs(cfg.video_save_folder, exist_ok=True)
89
  for i, input_dict in enumerate(prompts):
 
83
  prompts = read_prompts_from_file(cfg.batch_input_path)
84
  else:
85
  # Single prompt case
86
+ prompts = [{"prompt": prompt}]
87
 
88
  os.makedirs(cfg.video_save_folder, exist_ok=True)
89
  for i, input_dict in enumerate(prompts):