EthanZyh committed
Commit 97c7ba6 · Parent: 4402ae1

should work now

Files changed (2):
  1. download_diffusion.py +0 -36
  2. text2world_hf.py +2 -2
download_diffusion.py CHANGED
@@ -21,42 +21,6 @@ from huggingface_hub import snapshot_download
 from .convert_pixtral_ckpt import convert_pixtral_checkpoint
 
 
-def parse_args():
-    parser = argparse.ArgumentParser(description="Download NVIDIA Cosmos-1.0 Diffusion models from Hugging Face")
-    parser.add_argument(
-        "--model_sizes",
-        nargs="*",
-        default=[
-            "7B",
-            "14B",
-        ],  # Download all by default
-        choices=["7B", "14B"],
-        help="Which model sizes to download. Possible values: 7B, 14B",
-    )
-    parser.add_argument(
-        "--model_types",
-        nargs="*",
-        default=[
-            "Text2World",
-            "Video2World",
-        ],  # Download all by default
-        choices=["Text2World", "Video2World"],
-        help="Which model types to download. Possible values: Text2World, Video2World",
-    )
-    parser.add_argument(
-        "--cosmos_version",
-        type=str,
-        default="1.0",
-        choices=["1.0"],
-        help="Which version of Cosmos to download. Only 1.0 is available at the moment.",
-    )
-    parser.add_argument(
-        "--checkpoint_dir", type=str, default="checkpoints", help="Directory to save the downloaded checkpoints."
-    )
-    args = parser.parse_args()
-    return args
-
-
 def main(model_types, model_sizes, checkpoint_dir="checkpoints"):
     ORG_NAME = "nvidia"
 
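With parse_args() removed, the script no longer builds its own CLI, leaving main() as the entry point to call directly. A minimal sketch of such a call, assuming the module is importable as download_diffusion and reusing the values that the removed argparse defaults used to provide:

# Sketch only; the import path is an assumption, and the keyword values
# simply mirror the defaults of the deleted argparse options.
from download_diffusion import main

main(
    model_types=["Text2World", "Video2World"],  # was the --model_types default
    model_sizes=["7B", "14B"],                  # was the --model_sizes default
    checkpoint_dir="checkpoints",               # was the --checkpoint_dir default
)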
text2world_hf.py CHANGED
@@ -48,10 +48,10 @@ class DiffusionText2World(PreTrainedModel):
 
     def __init__(self, config=DiffusionText2WorldConfig()):
         super().__init__(config)
-        torch.enable_grad(False)  # TODO: do we need this?
+        torch.enable_grad(False)
         self.config = config
         inference_type = "text2world"
-        config.prompt = 1  # TODO: this is to hack args validation, maybe find a better way
+        config.prompt = 1  # this is to hack args validation, maybe find a better way
         validate_args(config, inference_type)
         del config.prompt
         self.pipeline = DiffusionText2WorldGenerationPipeline(
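Only the TODO markers change here; the behaviour is the same: __init__ still injects a placeholder config.prompt so validate_args() passes, then deletes it before building the pipeline. A minimal usage sketch under that reading, assuming the class and its config are importable from text2world_hf:

# Sketch only; the module path and default config values are assumptions.
from text2world_hf import DiffusionText2World, DiffusionText2WorldConfig

config = DiffusionText2World(config=DiffusionText2WorldConfig()).config if False else DiffusionText2WorldConfig()
model = DiffusionText2World(config)  # injects a placeholder config.prompt,
                                     # calls validate_args(config, "text2world"),
                                     # then deletes the placeholder again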