EthanZyh commited on
Commit
d73f308
·
1 Parent(s): a215d9f

add cosmos-1-diffusion-text2world

Browse files
.gitignore CHANGED
@@ -16,7 +16,9 @@
16
  # Misc
17
  outputs/
18
  checkpoints/*
19
- !checkpoints/README.md
 
 
20
 
21
  # Data types
22
  *.jit
 
16
  # Misc
17
  outputs/
18
  checkpoints/*
19
+ checkpoints/README.md
20
+ checkpoints
21
+ .gitignore
22
 
23
  # Data types
24
  *.jit
config_helper.py CHANGED
@@ -29,6 +29,7 @@ from omegaconf import DictConfig, OmegaConf
29
 
30
  from .log import log
31
  from .config import Config
 
32
 
33
 
34
  def is_attrs_or_dataclass(obj) -> bool:
@@ -163,6 +164,7 @@ def import_all_modules_from_package(package_path: str, reload: bool = False, ski
163
  reload (bool): Flag to determine whether to reload modules if they're already imported.
164
  skip_underscore (bool): If True, skips importing modules that start with an underscore.
165
  """
 
166
  log.debug(f"{'Reloading' if reload else 'Importing'} all modules from package {package_path}")
167
  package = importlib.import_module(package_path)
168
  package_directory = package.__path__
 
29
 
30
  from .log import log
31
  from .config import Config
32
+ from .inference import *
33
 
34
 
35
  def is_attrs_or_dataclass(obj) -> bool:
 
164
  reload (bool): Flag to determine whether to reload modules if they're already imported.
165
  skip_underscore (bool): If True, skips importing modules that start with an underscore.
166
  """
167
+ return # TODO: we do not use this
168
  log.debug(f"{'Reloading' if reload else 'Importing'} all modules from package {package_path}")
169
  package = importlib.import_module(package_path)
170
  package_directory = package.__path__
cosmos1/models/diffusion/config/inference/cosmos-1-diffusion-text2world.py → cosmos1diffusiontext2world.py RENAMED
File without changes
cosmos1/models/diffusion/config/inference/cosmos-1-diffusion-video2world.py → cosmos1diffusionvideo2world.py RENAMED
File without changes
df_config_config.py CHANGED
@@ -22,6 +22,9 @@ from .df_config_registry import register_configs
22
  from .config import Config as ori_Config
23
  from .config_helper import import_all_modules_from_package
24
 
 
 
 
25
 
26
  @attrs.define(slots=False)
27
  class Config(ori_Config):
 
22
  from .config import Config as ori_Config
23
  from .config_helper import import_all_modules_from_package
24
 
25
+ # I added importing here
26
+ from .cosmos1diffusiontext2world import LazyDict
27
+ from .cosmos1diffusionvideo2world import LazyDict
28
 
29
  @attrs.define(slots=False)
30
  class Config(ori_Config):
inference_utils.py CHANGED
@@ -29,6 +29,8 @@ from .model_v2w import DiffusionV2WModel
29
  from .config_helper import get_config_module, override
30
  from .utils_io import load_from_fileobj
31
  from .misc import misc
 
 
32
 
33
  TORCH_VERSION: Tuple[int, ...] = tuple(int(x) for x in torch.__version__.split(".")[:2])
34
  if TORCH_VERSION >= (1, 11):
@@ -272,8 +274,13 @@ def load_model_by_config(
272
  config_file="projects/cosmos_video/config/config.py",
273
  model_class=DiffusionT2WModel,
274
  ):
275
- config_module = get_config_module(config_file)
276
- config = importlib.import_module(config_module).make_config()
 
 
 
 
 
277
 
278
  config = override(config, ["--", f"experiment={config_job_name}"])
279
 
 
29
  from .config_helper import get_config_module, override
30
  from .utils_io import load_from_fileobj
31
  from .misc import misc
32
+ from .df_config_config import make_config
33
+ from .log import log
34
 
35
  TORCH_VERSION: Tuple[int, ...] = tuple(int(x) for x in torch.__version__.split(".")[:2])
36
  if TORCH_VERSION >= (1, 11):
 
274
  config_file="projects/cosmos_video/config/config.py",
275
  model_class=DiffusionT2WModel,
276
  ):
277
+ # TODO: We need to modify this for huggingface because the config file path is different
278
+ # config_module = get_config_module(config_file)
279
+ # config = importlib.import_module(config_module).make_config()
280
+ if model_class in (DiffusionT2WModel, DiffusionV2WModel):
281
+ config = make_config()
282
+ else:
283
+ raise NotImplementedError("TODO: didn't implement autoregression")
284
 
285
  config = override(config, ["--", f"experiment={config_job_name}"])
286