move df_config out
- cosmos1/models/autoregressive/diffusion_decoder/config/config_latent_diffusion_decoder.py +2 -2
- cosmos1/models/autoregressive/diffusion_decoder/network.py +1 -1
- cosmos1/models/autoregressive/inference/world_generation_pipeline.py +1 -1
- cosmos1/models/autoregressive/nemo/inference/general.py +1 -1
- cosmos1/models/diffusion/networks/general_dit_video_conditioned.py +1 -1
- cosmos1/models/diffusion/config/base/model.py → df_base_model.py +1 -1
- cosmos1/models/diffusion/config/base/net.py → df_config_base_net.py +3 -3
- cosmos1/models/diffusion/config/base/tokenizer.py → df_config_base_tokenizer.py +2 -2
- cosmos1/models/diffusion/config/config.py → df_config_config.py +5 -5
- cosmos1/models/diffusion/config/registry.py → df_config_registry.py +3 -3
- cosmos1/models/diffusion/networks/general_dit.py → general_dit.py +4 -4
- text2world_hf.py +1 -0
- world_generation_pipeline.py +2 -2
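Taken together, the renames flatten the cosmos1/models/diffusion/config package into top-level df_* modules, and every touched file switches its imports to the new flat or relative module names. A minimal sketch of the resulting import style follows; the new module names are confirmed by the hunks below, while the pre-rename dotted paths are only inferred from the rename mapping above (the removed lines are truncated in the rendered diff):

# Sketch only: importing the flattened diffusion config modules after this commit.
# Pre-rename equivalents (e.g. cosmos1.models.diffusion.config.registry) are inferred
# from the rename list above, not copied from the truncated removed lines.
from df_config_registry import register_configs               # was cosmos1/models/diffusion/config/registry.py
from df_base_model import LatentDiffusionDecoderModelConfig   # was cosmos1/models/diffusion/config/base/model.py
from df_config_config import Config                           # subclasses the original Config (see the config.py hunk)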
cosmos1/models/autoregressive/diffusion_decoder/config/config_latent_diffusion_decoder.py
CHANGED
@@ -18,8 +18,8 @@ from typing import Any, List
 import attrs
 
 from cosmos1.models.autoregressive.diffusion_decoder.config.registry import register_configs as register_dd_configs
-from
-from
+from df_base_model import LatentDiffusionDecoderModelConfig
+from df_config_registry import register_configs
 from . import config
 from config_helper import import_all_modules_from_package
 
cosmos1/models/autoregressive/diffusion_decoder/network.py
CHANGED
@@ -21,7 +21,7 @@ from torch import nn
 from torchvision import transforms
 
 from blocks import PatchEmbed
-from
+from general_dit import GeneralDIT
 
 
 class DiffusionDecoderGeneralDIT(GeneralDIT):
cosmos1/models/autoregressive/inference/world_generation_pipeline.py
CHANGED
@@ -40,7 +40,7 @@ from inference_utils import (
     load_network_model,
     load_tokenizer_model,
 )
-from . import misc
+from .misc import misc, Color, timer
 
 
 def detect_model_size_from_ckpt_path(ckpt_path: str) -> str:
cosmos1/models/autoregressive/nemo/inference/general.py
CHANGED
@@ -36,7 +36,7 @@ from nemo.lightning.ckpt_utils import ckpt_to_context_subdir
 from cosmos1.models.autoregressive.nemo.utils import run_diffusion_decoder_model
 from discrete_video import DiscreteVideoFSQJITTokenizer
 from cosmos1.models.autoregressive.utils.inference import load_vision_input
-from . import presets as guardrail_presets
+from .presets import presets as guardrail_presets
 from .log import log
 
 torch._C._jit_set_texpr_fuser_enabled(False)
cosmos1/models/diffusion/networks/general_dit_video_conditioned.py
CHANGED
@@ -21,7 +21,7 @@ from torch import nn
 
 from conditioner import DataType
 from blocks import TimestepEmbedding, Timesteps
-from
+from general_dit import GeneralDIT
 from .log import log
 
 
cosmos1/models/diffusion/config/base/model.py → df_base_model.py
RENAMED
@@ -17,7 +17,7 @@ from typing import List
 
 import attrs
 
-from lazy_config_init import LazyDict
+from .lazy_config_init import LazyDict
 
 
 @attrs.define(slots=False)
cosmos1/models/diffusion/config/base/net.py → df_config_base_net.py
RENAMED
@@ -15,9 +15,9 @@
 
 import copy
 
-from
-from lazy_config_init import LazyCall as L
-from lazy_config_init import LazyDict
+from .general_dit import GeneralDIT
+from .lazy_config_init import LazyCall as L
+from .lazy_config_init import LazyDict
 
 FADITV2Config: LazyDict = L(GeneralDIT)(
     max_img_h=240,
cosmos1/models/diffusion/config/base/tokenizer.py → df_config_base_tokenizer.py
RENAMED
@@ -15,8 +15,8 @@
 
 import omegaconf
 
-from pretrained_vae import JITVAE, JointImageVideoSharedJITTokenizer, VideoJITTokenizer
-from lazy_config_init import LazyCall as L
+from .pretrained_vae import JITVAE, JointImageVideoSharedJITTokenizer, VideoJITTokenizer
+from .lazy_config_init import LazyCall as L
 
 TOKENIZER_OPTIONS = {}
 
cosmos1/models/diffusion/config/config.py → df_config_config.py
RENAMED
@@ -17,14 +17,14 @@ from typing import Any, List
 
 import attrs
 
-from
-from
-from . import
-from config_helper import import_all_modules_from_package
+from .df_base_model import DefaultModelConfig
+from .df_config_registry import register_configs
+from .config import Config as ori_Config
+from .config_helper import import_all_modules_from_package
 
 
 @attrs.define(slots=False)
-class Config(
+class Config(ori_Config):
     # default config groups that will be used unless overwritten
     # see config groups in registry.py
     defaults: List[Any] = attrs.field(
cosmos1/models/diffusion/config/registry.py → df_config_registry.py
RENAMED
@@ -15,13 +15,13 @@
 
 from hydra.core.config_store import ConfigStore
 
-from config_base_conditioner import (
+from .config_base_conditioner import (
     BaseVideoConditionerConfig,
     VideoConditionerFpsSizePaddingConfig,
     VideoExtendConditionerConfig,
 )
-from
-from
+from .df_config_base_net import FADITV2_14B_Config, FADITV2Config
+from .df_config_base_tokenizer import get_cosmos_diffusion_tokenizer_comp8x8x8
 
 
 def register_net(cs):
cosmos1/models/diffusion/networks/general_dit.py → general_dit.py
RENAMED
@@ -24,16 +24,16 @@ from einops import rearrange
 from torch import nn
 from torchvision import transforms
 
-from conditioner import DataType
-from attention import get_normalization
-from blocks import (
+from .conditioner import DataType
+from .attention import get_normalization
+from .blocks import (
     FinalLayer,
     GeneralDITTransformerBlock,
     PatchEmbed,
     TimestepEmbedding,
     Timesteps,
 )
-from position_embedding import LearnablePosEmbAxis, VideoRopePosition3DEmb
+from .position_embedding import LearnablePosEmbAxis, VideoRopePosition3DEmb
 from .log import log
 
 
text2world_hf.py
CHANGED
@@ -8,6 +8,7 @@ from .world_generation_pipeline import DiffusionText2WorldGenerationPipeline
 from .log import log
 from .misc import misc, Color, timer
 from .utils_io import read_prompts_from_file, save_video
+from .df_config_config import attrs # this makes huggingface to download the file
 
 
 # custom config class
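The new import in text2world_hf.py exists purely for its side effect: as the inline comment notes, referencing df_config_config from the entry-point file makes Hugging Face's remote-code loader fetch that file along with text2world_hf.py. A hypothetical loading sketch, assuming the pipeline is consumed through transformers with trust_remote_code enabled (the Auto class and repo id are placeholders, not taken from the diff):

# Hypothetical usage; only the remote-code download mechanism is the point here.
from transformers import AutoModel

# trust_remote_code lets text2world_hf.py run from the hub; its relative imports
# (including df_config_config.py) are downloaded alongside it.
model = AutoModel.from_pretrained(
    "org/placeholder-text2world-repo",  # placeholder repo id, not from the diff
    trust_remote_code=True,
)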
world_generation_pipeline.py
CHANGED
@@ -140,7 +140,7 @@ class DiffusionText2WorldGenerationPipeline(BaseWorldGenerationPipeline):
     def _load_model(self):
         self.model = load_model_by_config(
             config_job_name=self.model_name,
-            config_file="
+            config_file="df_config_config.py",
             model_class=DiffusionT2WModel,
         )
 
@@ -468,7 +468,7 @@ class DiffusionVideo2WorldGenerationPipeline(DiffusionText2WorldGenerationPipeline):
     def _load_model(self):
         self.model = load_model_by_config(
             config_job_name=self.model_name,
-            config_file="
+            config_file="df_config_config.py",
             model_class=DiffusionV2WModel,
         )
 
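Both pipeline hunks simply repoint load_model_by_config at the renamed flat config file. A brief sketch of the resulting call site, with the keyword arguments taken from the diff; the indentation and class context are reconstructed, and load_model_by_config / DiffusionT2WModel are defined elsewhere in the repository:

# Sketch of the updated _load_model body; only config_file changes in this commit,
# now naming the flattened df_config_config.py instead of the old packaged config
# path (which is truncated in the rendered diff above).
def _load_model(self):
    self.model = load_model_by_config(
        config_job_name=self.model_name,
        config_file="df_config_config.py",
        model_class=DiffusionT2WModel,
    )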