Fixed config loading
Browse files- .gitignore +2 -1
- cosmos1/models/autoregressive/diffusion_decoder/config/registry.py → ar_diffusion_decoder_config_registry.py +0 -0
- cosmos1/models/autoregressive/diffusion_decoder/config/config_latent_diffusion_decoder.py → config_latent_diffusion_decoder.py +2 -2
- inference_utils.py +5 -2
- ldd_config.py +61 -0
.gitignore
CHANGED
@@ -1 +1,2 @@
|
|
1 |
-
huggingface.txt
|
|
|
|
1 |
+
huggingface.txt
|
2 |
+
checkpoints/
|
cosmos1/models/autoregressive/diffusion_decoder/config/registry.py → ar_diffusion_decoder_config_registry.py
RENAMED
File without changes
|
cosmos1/models/autoregressive/diffusion_decoder/config/config_latent_diffusion_decoder.py → config_latent_diffusion_decoder.py
RENAMED
@@ -17,10 +17,10 @@ from typing import Any, List
|
|
17 |
|
18 |
import attrs
|
19 |
|
20 |
-
from cosmos1.models.autoregressive.diffusion_decoder.config.
|
21 |
from df_base_model import LatentDiffusionDecoderModelConfig
|
22 |
from df_config_registry import register_configs
|
23 |
-
from . import config
|
24 |
from config_helper import import_all_modules_from_package
|
25 |
|
26 |
|
|
|
17 |
|
18 |
import attrs
|
19 |
|
20 |
+
from cosmos1.models.autoregressive.diffusion_decoder.config.ar_diffusion_decoder_config_registry import register_configs as register_dd_configs
|
21 |
from df_base_model import LatentDiffusionDecoderModelConfig
|
22 |
from df_config_registry import register_configs
|
23 |
+
from .cosmos1.models.autoregressive.diffusion_decoder.config import config
|
24 |
from config_helper import import_all_modules_from_package
|
25 |
|
26 |
|
inference_utils.py
CHANGED
@@ -26,10 +26,12 @@ import torchvision.transforms.functional as transforms_F
|
|
26 |
|
27 |
from .model_t2w import DiffusionT2WModel
|
28 |
from .model_v2w import DiffusionV2WModel
|
|
|
29 |
from .config_helper import get_config_module, override
|
30 |
from .utils_io import load_from_fileobj
|
31 |
from .misc import misc
|
32 |
from .df_config_config import make_config
|
|
|
33 |
from .log import log
|
34 |
|
35 |
TORCH_VERSION: Tuple[int, ...] = tuple(int(x) for x in torch.__version__.split(".")[:2])
|
@@ -279,8 +281,9 @@ def load_model_by_config(
|
|
279 |
# config = importlib.import_module(config_module).make_config()
|
280 |
if model_class in (DiffusionT2WModel, DiffusionV2WModel):
|
281 |
config = make_config()
|
282 |
-
|
283 |
-
|
|
|
284 |
|
285 |
config = override(config, ["--", f"experiment={config_job_name}"])
|
286 |
|
|
|
26 |
|
27 |
from .model_t2w import DiffusionT2WModel
|
28 |
from .model_v2w import DiffusionV2WModel
|
29 |
+
from .ar_diffusion_decoder_model import LatentDiffusionDecoderModel
|
30 |
from .config_helper import get_config_module, override
|
31 |
from .utils_io import load_from_fileobj
|
32 |
from .misc import misc
|
33 |
from .df_config_config import make_config
|
34 |
+
from .ldd_config import make_config as ldd_make_config
|
35 |
from .log import log
|
36 |
|
37 |
TORCH_VERSION: Tuple[int, ...] = tuple(int(x) for x in torch.__version__.split(".")[:2])
|
|
|
281 |
# config = importlib.import_module(config_module).make_config()
|
282 |
if model_class in (DiffusionT2WModel, DiffusionV2WModel):
|
283 |
config = make_config()
|
284 |
+
elif model_class in (LatentDiffusionDecoderModel):
|
285 |
+
config = ldd_make_config()
|
286 |
+
|
287 |
|
288 |
config = override(config, ["--", f"experiment={config_job_name}"])
|
289 |
|
ldd_config.py
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
from typing import Any, List
|
17 |
+
|
18 |
+
import attrs
|
19 |
+
|
20 |
+
from .ar_diffusion_decoder_config_registry import register_configs as register_dd_configs
|
21 |
+
from .df_base_model import LatentDiffusionDecoderModelConfig
|
22 |
+
from .df_config_registry import register_configs
|
23 |
+
from .config import Config as ori_Config
|
24 |
+
from .config_helper import import_all_modules_from_package
|
25 |
+
|
26 |
+
|
27 |
+
@attrs.define(slots=False)
|
28 |
+
class Config(ori_Config):
|
29 |
+
# default config groups that will be used unless overwritten
|
30 |
+
# see config groups in registry.py
|
31 |
+
defaults: List[Any] = attrs.field(
|
32 |
+
factory=lambda: [
|
33 |
+
"_self_",
|
34 |
+
{"net": None},
|
35 |
+
{"conditioner": "basic"},
|
36 |
+
{"tokenizer": "tokenizer"},
|
37 |
+
{"tokenizer_corruptor": None},
|
38 |
+
{"latent_corruptor": None},
|
39 |
+
{"pixel_corruptor": None},
|
40 |
+
{"experiment": None},
|
41 |
+
]
|
42 |
+
)
|
43 |
+
|
44 |
+
|
45 |
+
def make_config():
|
46 |
+
c = Config(model=LatentDiffusionDecoderModelConfig())
|
47 |
+
|
48 |
+
# Specifying values through instances of attrs
|
49 |
+
c.job.project = "cosmos_video4"
|
50 |
+
c.job.group = "debug"
|
51 |
+
c.job.name = "delete_${now:%Y-%m-%d}_${now:%H-%M-%S}"
|
52 |
+
|
53 |
+
# Call this function to register config groups for advanced overriding.
|
54 |
+
register_configs()
|
55 |
+
register_dd_configs()
|
56 |
+
|
57 |
+
# experiment config are defined in the experiment folder
|
58 |
+
# call import_all_modules_from_package to register them
|
59 |
+
import_all_modules_from_package("cosmos1.models.diffusion.config.inference", reload=True)
|
60 |
+
import_all_modules_from_package("cosmos1.models.autoregressive.diffusion_decoder.config.inference", reload=True)
|
61 |
+
return c
|