sheldonl commited on
Commit
4783839
·
1 Parent(s): 9320716

Fixed config loading

Browse files
.gitignore CHANGED
@@ -1 +1,2 @@
1
- huggingface.txt
 
 
1
+ huggingface.txt
2
+ checkpoints/
cosmos1/models/autoregressive/diffusion_decoder/config/registry.py → ar_diffusion_decoder_config_registry.py RENAMED
File without changes
cosmos1/models/autoregressive/diffusion_decoder/config/config_latent_diffusion_decoder.py → config_latent_diffusion_decoder.py RENAMED
@@ -17,10 +17,10 @@ from typing import Any, List
17
 
18
  import attrs
19
 
20
- from cosmos1.models.autoregressive.diffusion_decoder.config.registry import register_configs as register_dd_configs
21
  from df_base_model import LatentDiffusionDecoderModelConfig
22
  from df_config_registry import register_configs
23
- from . import config
24
  from config_helper import import_all_modules_from_package
25
 
26
 
 
17
 
18
  import attrs
19
 
20
+ from cosmos1.models.autoregressive.diffusion_decoder.config.ar_diffusion_decoder_config_registry import register_configs as register_dd_configs
21
  from df_base_model import LatentDiffusionDecoderModelConfig
22
  from df_config_registry import register_configs
23
+ from .cosmos1.models.autoregressive.diffusion_decoder.config import config
24
  from config_helper import import_all_modules_from_package
25
 
26
 
inference_utils.py CHANGED
@@ -26,10 +26,12 @@ import torchvision.transforms.functional as transforms_F
26
 
27
  from .model_t2w import DiffusionT2WModel
28
  from .model_v2w import DiffusionV2WModel
 
29
  from .config_helper import get_config_module, override
30
  from .utils_io import load_from_fileobj
31
  from .misc import misc
32
  from .df_config_config import make_config
 
33
  from .log import log
34
 
35
  TORCH_VERSION: Tuple[int, ...] = tuple(int(x) for x in torch.__version__.split(".")[:2])
@@ -279,8 +281,9 @@ def load_model_by_config(
279
  # config = importlib.import_module(config_module).make_config()
280
  if model_class in (DiffusionT2WModel, DiffusionV2WModel):
281
  config = make_config()
282
- else:
283
- raise NotImplementedError("TODO: didn't implement autoregression")
 
284
 
285
  config = override(config, ["--", f"experiment={config_job_name}"])
286
 
 
26
 
27
  from .model_t2w import DiffusionT2WModel
28
  from .model_v2w import DiffusionV2WModel
29
+ from .ar_diffusion_decoder_model import LatentDiffusionDecoderModel
30
  from .config_helper import get_config_module, override
31
  from .utils_io import load_from_fileobj
32
  from .misc import misc
33
  from .df_config_config import make_config
34
+ from .ldd_config import make_config as ldd_make_config
35
  from .log import log
36
 
37
  TORCH_VERSION: Tuple[int, ...] = tuple(int(x) for x in torch.__version__.split(".")[:2])
 
281
  # config = importlib.import_module(config_module).make_config()
282
  if model_class in (DiffusionT2WModel, DiffusionV2WModel):
283
  config = make_config()
284
+ elif model_class in (LatentDiffusionDecoderModel):
285
+ config = ldd_make_config()
286
+
287
 
288
  config = override(config, ["--", f"experiment={config_job_name}"])
289
 
ldd_config.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import Any, List
17
+
18
+ import attrs
19
+
20
+ from .ar_diffusion_decoder_config_registry import register_configs as register_dd_configs
21
+ from .df_base_model import LatentDiffusionDecoderModelConfig
22
+ from .df_config_registry import register_configs
23
+ from .config import Config as ori_Config
24
+ from .config_helper import import_all_modules_from_package
25
+
26
+
27
+ @attrs.define(slots=False)
28
+ class Config(ori_Config):
29
+ # default config groups that will be used unless overwritten
30
+ # see config groups in registry.py
31
+ defaults: List[Any] = attrs.field(
32
+ factory=lambda: [
33
+ "_self_",
34
+ {"net": None},
35
+ {"conditioner": "basic"},
36
+ {"tokenizer": "tokenizer"},
37
+ {"tokenizer_corruptor": None},
38
+ {"latent_corruptor": None},
39
+ {"pixel_corruptor": None},
40
+ {"experiment": None},
41
+ ]
42
+ )
43
+
44
+
45
+ def make_config():
46
+ c = Config(model=LatentDiffusionDecoderModelConfig())
47
+
48
+ # Specifying values through instances of attrs
49
+ c.job.project = "cosmos_video4"
50
+ c.job.group = "debug"
51
+ c.job.name = "delete_${now:%Y-%m-%d}_${now:%H-%M-%S}"
52
+
53
+ # Call this function to register config groups for advanced overriding.
54
+ register_configs()
55
+ register_dd_configs()
56
+
57
+ # experiment config are defined in the experiment folder
58
+ # call import_all_modules_from_package to register them
59
+ import_all_modules_from_package("cosmos1.models.diffusion.config.inference", reload=True)
60
+ import_all_modules_from_package("cosmos1.models.autoregressive.diffusion_decoder.config.inference", reload=True)
61
+ return c