EthanZyh committed on
Commit
b8232e3
·
1 Parent(s): 6246596

supported downloading ckpt

Browse files
README.md CHANGED
@@ -1,5 +1,5 @@
1
 
2
- ![Cosmos Logo](assets/cosmos-logo.png)
3
 
4
  --------------------------------------------------------------------------------
5
  ### [Website](https://www.nvidia.com/en-us/ai/cosmos/) | [HuggingFace](https://huggingface.co/collections/nvidia/cosmos-6751e884dc10e013a0a0d8e6) | [GPU-free Preview](https://build.nvidia.com/explore/discover) | [Paper](https://arxiv.org/abs/2501.03575) | [Paper Website](https://research.nvidia.com/labs/dir/cosmos1/)
 
1
 
2
+ ![Cosmos Logo](https://github.com/NVIDIA/Cosmos/raw/main/assets/cosmos-logo.png)
3
 
4
  --------------------------------------------------------------------------------
5
  ### [Website](https://www.nvidia.com/en-us/ai/cosmos/) | [HuggingFace](https://huggingface.co/collections/nvidia/cosmos-6751e884dc10e013a0a0d8e6) | [GPU-free Preview](https://build.nvidia.com/explore/discover) | [Paper](https://arxiv.org/abs/2501.03575) | [Paper Website](https://research.nvidia.com/labs/dir/cosmos1/)
cosmos1/scripts/convert_pixtral_ckpt.py → convert_pixtral_ckpt.py RENAMED
File without changes
cosmos1/scripts/download_autoregressive.py → download_autoregressive.py RENAMED
File without changes
cosmos1/scripts/download_diffusion.py → download_diffusion.py RENAMED
@@ -18,7 +18,7 @@ from pathlib import Path
18
 
19
  from huggingface_hub import snapshot_download
20
 
21
- from cosmos1.scripts.convert_pixtral_ckpt import convert_pixtral_checkpoint
22
 
23
 
24
  def parse_args():
@@ -57,7 +57,7 @@ def parse_args():
57
  return args
58
 
59
 
60
- def main(args):
61
  ORG_NAME = "nvidia"
62
 
63
  # Mapping from size argument to Hugging Face repository name
@@ -72,18 +72,18 @@ def main(args):
72
  "Cosmos-1.0-Tokenizer-CV8x8x8",
73
  ]
74
 
75
- if "Text2World" in args.model_types:
76
  extra_models.append("Cosmos-1.0-Prompt-Upsampler-12B-Text2World")
77
 
78
  # Create local checkpoints folder
79
- checkpoints_dir = Path(args.checkpoint_dir)
80
  checkpoints_dir.mkdir(parents=True, exist_ok=True)
81
 
82
  download_kwargs = dict(allow_patterns=["README.md", "model.pt", "config.json", "*.jit"])
83
 
84
  # Download the requested Autoregressive models
85
- for size in args.model_sizes:
86
- for model_type in args.model_types:
87
  suffix = f"-{model_type}"
88
  model_name = model_map[size] + suffix
89
  repo_id = f"{ORG_NAME}/{model_name}"
@@ -109,15 +109,11 @@ def main(args):
109
  local_dir_use_symlinks=False,
110
  )
111
 
112
- if "Video2World" in args.model_types:
113
  # Prompt Upsampler for Cosmos-1.0-Diffusion-Video2World models
114
  convert_pixtral_checkpoint(
115
- checkpoint_dir=args.checkpoint_dir,
116
  checkpoint_name="Pixtral-12B",
117
  vit_type="pixtral-12b-vit",
118
  )
119
 
120
-
121
- if __name__ == "__main__":
122
- args = parse_args()
123
- main(args)
 
18
 
19
  from huggingface_hub import snapshot_download
20
 
21
+ from .convert_pixtral_ckpt import convert_pixtral_checkpoint
22
 
23
 
24
  def parse_args():
 
57
  return args
58
 
59
 
60
+ def main(model_types, model_sizes, checkpoint_dir="checkpoints"):
61
  ORG_NAME = "nvidia"
62
 
63
  # Mapping from size argument to Hugging Face repository name
 
72
  "Cosmos-1.0-Tokenizer-CV8x8x8",
73
  ]
74
 
75
+ if "Text2World" in model_types:
76
  extra_models.append("Cosmos-1.0-Prompt-Upsampler-12B-Text2World")
77
 
78
  # Create local checkpoints folder
79
+ checkpoints_dir = Path(checkpoint_dir)
80
  checkpoints_dir.mkdir(parents=True, exist_ok=True)
81
 
82
  download_kwargs = dict(allow_patterns=["README.md", "model.pt", "config.json", "*.jit"])
83
 
84
  # Download the requested Autoregressive models
85
+ for size in model_sizes:
86
+ for model_type in model_types:
87
  suffix = f"-{model_type}"
88
  model_name = model_map[size] + suffix
89
  repo_id = f"{ORG_NAME}/{model_name}"
 
109
  local_dir_use_symlinks=False,
110
  )
111
 
112
+ if "Video2World" in model_types:
113
  # Prompt Upsampler for Cosmos-1.0-Diffusion-Video2World models
114
  convert_pixtral_checkpoint(
115
+ checkpoint_dir=checkpoint_dir,
116
  checkpoint_name="Pixtral-12B",
117
  vit_type="pixtral-12b-vit",
118
  )
119
 
 
 
 
 
text2world_hf.py CHANGED
@@ -9,6 +9,7 @@ from .log import log
9
  from .misc import misc, Color, timer
10
  from .utils_io import read_prompts_from_file, save_video
11
  from .df_config_config import attrs # this makes huggingface to download the file
 
12
 
13
 
14
  # custom config class
@@ -133,5 +134,9 @@ class DiffusionText2World(PreTrainedModel):
133
  other_args = kwargs.copy()
134
  other_args.pop("config")
135
  config.update(other_args)
 
 
 
 
136
  model = cls(config)
137
  return model
 
9
  from .misc import misc, Color, timer
10
  from .utils_io import read_prompts_from_file, save_video
11
  from .df_config_config import attrs # this makes huggingface to download the file
12
+ from .download_diffusion import main as download_diffusion
13
 
14
 
15
  # custom config class
 
134
  other_args = kwargs.copy()
135
  other_args.pop("config")
136
  config.update(other_args)
137
+ breakpoint()
138
+ model_sizes = ["7B",] if "7B" in config.diffusion_transformer_dir else ["14B",]
139
+ model_types = ["Text2World",]
140
+ download_diffusion(model_types, model_sizes, config.checkpoint_dir)
141
  model = cls(config)
142
  return model