EthanZyh commited on
Commit
a0b35da
·
1 Parent(s): f4c7b3e

add some comments

Browse files
Files changed (1) hide show
  1. text2world_hf.py +6 -1
text2world_hf.py CHANGED
@@ -3,12 +3,15 @@ import argparse
3
  import torch
4
  from transformers import PreTrainedModel, PretrainedConfig
5
 
 
6
  from .cosmos1.models.diffusion.inference.inference_utils import add_common_arguments, validate_args
7
  from .cosmos1.models.diffusion.inference.world_generation_pipeline import DiffusionText2WorldGenerationPipeline
8
  import .cosmos1.utils.log as log
9
  import .cosmos1.utils.misc as misc
10
  from .cosmos1.utils.io import read_prompts_from_file, save_video
11
 
 
 
12
  class DiffusionText2WorldConfig(PretrainedConfig):
13
  model_type = "DiffusionText2World"
14
  def __init__(self, **kwargs):
@@ -38,6 +41,7 @@ class DiffusionText2WorldConfig(PretrainedConfig):
38
  self.offload_guardrail_models = kwargs.get("offload_guardrail_models", False)
39
 
40
 
 
41
  class DiffusionText2World(PreTrainedModel):
42
  config_class = DiffusionText2WorldConfig
43
 
@@ -69,6 +73,7 @@ class DiffusionText2World(PreTrainedModel):
69
  seed=config.seed,
70
  )
71
 
 
72
  def forward(self, prompt):
73
  cfg = self.config
74
  # Handle multiple prompts if prompt file is provided
@@ -118,7 +123,7 @@ class DiffusionText2World(PreTrainedModel):
118
  log.info(f"Saved prompt to {prompt_save_path}")
119
 
120
  def save_pretrained(self, save_directory, **kwargs):
121
- # We don't save anything
122
  pass
123
 
124
  @classmethod
 
3
  import torch
4
  from transformers import PreTrainedModel, PretrainedConfig
5
 
6
+ # TODO: This is a bug to fix. Huggingface cannot download .cosmos1.models.diffusion.inference.inference_utils because it's in a subfolder.
7
  from .cosmos1.models.diffusion.inference.inference_utils import add_common_arguments, validate_args
8
  from .cosmos1.models.diffusion.inference.world_generation_pipeline import DiffusionText2WorldGenerationPipeline
9
  import .cosmos1.utils.log as log
10
  import .cosmos1.utils.misc as misc
11
  from .cosmos1.utils.io import read_prompts_from_file, save_video
12
 
13
+
14
+ # custom config class
15
  class DiffusionText2WorldConfig(PretrainedConfig):
16
  model_type = "DiffusionText2World"
17
  def __init__(self, **kwargs):
 
41
  self.offload_guardrail_models = kwargs.get("offload_guardrail_models", False)
42
 
43
 
44
+ # custom model calss
45
  class DiffusionText2World(PreTrainedModel):
46
  config_class = DiffusionText2WorldConfig
47
 
 
73
  seed=config.seed,
74
  )
75
 
76
+ # modifed from text2world.py demo function
77
  def forward(self, prompt):
78
  cfg = self.config
79
  # Handle multiple prompts if prompt file is provided
 
123
  log.info(f"Saved prompt to {prompt_save_path}")
124
 
125
  def save_pretrained(self, save_directory, **kwargs):
126
+ # We don't save anything, but need this function to override
127
  pass
128
 
129
  @classmethod