add some comments
Browse files- text2world_hf.py +6 -1
text2world_hf.py
CHANGED
@@ -3,12 +3,15 @@ import argparse
|
|
3 |
import torch
|
4 |
from transformers import PreTrainedModel, PretrainedConfig
|
5 |
|
|
|
6 |
from .cosmos1.models.diffusion.inference.inference_utils import add_common_arguments, validate_args
|
7 |
from .cosmos1.models.diffusion.inference.world_generation_pipeline import DiffusionText2WorldGenerationPipeline
|
8 |
import .cosmos1.utils.log as log
|
9 |
import .cosmos1.utils.misc as misc
|
10 |
from .cosmos1.utils.io import read_prompts_from_file, save_video
|
11 |
|
|
|
|
|
12 |
class DiffusionText2WorldConfig(PretrainedConfig):
|
13 |
model_type = "DiffusionText2World"
|
14 |
def __init__(self, **kwargs):
|
@@ -38,6 +41,7 @@ class DiffusionText2WorldConfig(PretrainedConfig):
|
|
38 |
self.offload_guardrail_models = kwargs.get("offload_guardrail_models", False)
|
39 |
|
40 |
|
|
|
41 |
class DiffusionText2World(PreTrainedModel):
|
42 |
config_class = DiffusionText2WorldConfig
|
43 |
|
@@ -69,6 +73,7 @@ class DiffusionText2World(PreTrainedModel):
|
|
69 |
seed=config.seed,
|
70 |
)
|
71 |
|
|
|
72 |
def forward(self, prompt):
|
73 |
cfg = self.config
|
74 |
# Handle multiple prompts if prompt file is provided
|
@@ -118,7 +123,7 @@ class DiffusionText2World(PreTrainedModel):
|
|
118 |
log.info(f"Saved prompt to {prompt_save_path}")
|
119 |
|
120 |
def save_pretrained(self, save_directory, **kwargs):
|
121 |
-
# We don't save anything
|
122 |
pass
|
123 |
|
124 |
@classmethod
|
|
|
3 |
import torch
|
4 |
from transformers import PreTrainedModel, PretrainedConfig
|
5 |
|
6 |
+
# TODO: This is a bug to fix. Huggingface cannot download .cosmos1.models.diffusion.inference.inference_utils because it's in a subfolder.
|
7 |
from .cosmos1.models.diffusion.inference.inference_utils import add_common_arguments, validate_args
|
8 |
from .cosmos1.models.diffusion.inference.world_generation_pipeline import DiffusionText2WorldGenerationPipeline
|
9 |
import .cosmos1.utils.log as log
|
10 |
import .cosmos1.utils.misc as misc
|
11 |
from .cosmos1.utils.io import read_prompts_from_file, save_video
|
12 |
|
13 |
+
|
14 |
+
# custom config class
|
15 |
class DiffusionText2WorldConfig(PretrainedConfig):
|
16 |
model_type = "DiffusionText2World"
|
17 |
def __init__(self, **kwargs):
|
|
|
41 |
self.offload_guardrail_models = kwargs.get("offload_guardrail_models", False)
|
42 |
|
43 |
|
44 |
+
# custom model calss
|
45 |
class DiffusionText2World(PreTrainedModel):
|
46 |
config_class = DiffusionText2WorldConfig
|
47 |
|
|
|
73 |
seed=config.seed,
|
74 |
)
|
75 |
|
76 |
+
# modifed from text2world.py demo function
|
77 |
def forward(self, prompt):
|
78 |
cfg = self.config
|
79 |
# Handle multiple prompts if prompt file is provided
|
|
|
123 |
log.info(f"Saved prompt to {prompt_save_path}")
|
124 |
|
125 |
def save_pretrained(self, save_directory, **kwargs):
|
126 |
+
# We don't save anything, but need this function to override
|
127 |
pass
|
128 |
|
129 |
@classmethod
|