should work now
Browse files- download_diffusion.py +0 -36
- text2world_hf.py +2 -2
download_diffusion.py
CHANGED
@@ -21,42 +21,6 @@ from huggingface_hub import snapshot_download
|
|
21 |
from .convert_pixtral_ckpt import convert_pixtral_checkpoint
|
22 |
|
23 |
|
24 |
-
def parse_args():
|
25 |
-
parser = argparse.ArgumentParser(description="Download NVIDIA Cosmos-1.0 Diffusion models from Hugging Face")
|
26 |
-
parser.add_argument(
|
27 |
-
"--model_sizes",
|
28 |
-
nargs="*",
|
29 |
-
default=[
|
30 |
-
"7B",
|
31 |
-
"14B",
|
32 |
-
], # Download all by default
|
33 |
-
choices=["7B", "14B"],
|
34 |
-
help="Which model sizes to download. Possible values: 7B, 14B",
|
35 |
-
)
|
36 |
-
parser.add_argument(
|
37 |
-
"--model_types",
|
38 |
-
nargs="*",
|
39 |
-
default=[
|
40 |
-
"Text2World",
|
41 |
-
"Video2World",
|
42 |
-
], # Download all by default
|
43 |
-
choices=["Text2World", "Video2World"],
|
44 |
-
help="Which model types to download. Possible values: Text2World, Video2World",
|
45 |
-
)
|
46 |
-
parser.add_argument(
|
47 |
-
"--cosmos_version",
|
48 |
-
type=str,
|
49 |
-
default="1.0",
|
50 |
-
choices=["1.0"],
|
51 |
-
help="Which version of Cosmos to download. Only 1.0 is available at the moment.",
|
52 |
-
)
|
53 |
-
parser.add_argument(
|
54 |
-
"--checkpoint_dir", type=str, default="checkpoints", help="Directory to save the downloaded checkpoints."
|
55 |
-
)
|
56 |
-
args = parser.parse_args()
|
57 |
-
return args
|
58 |
-
|
59 |
-
|
60 |
def main(model_types, model_sizes, checkpoint_dir="checkpoints"):
|
61 |
ORG_NAME = "nvidia"
|
62 |
|
|
|
21 |
from .convert_pixtral_ckpt import convert_pixtral_checkpoint
|
22 |
|
23 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
def main(model_types, model_sizes, checkpoint_dir="checkpoints"):
|
25 |
ORG_NAME = "nvidia"
|
26 |
|
text2world_hf.py
CHANGED
@@ -48,10 +48,10 @@ class DiffusionText2World(PreTrainedModel):
|
|
48 |
|
49 |
def __init__(self, config=DiffusionText2WorldConfig()):
|
50 |
super().__init__(config)
|
51 |
-
torch.enable_grad(False)
|
52 |
self.config = config
|
53 |
inference_type = "text2world"
|
54 |
-
config.prompt = 1 #
|
55 |
validate_args(config, inference_type)
|
56 |
del config.prompt
|
57 |
self.pipeline = DiffusionText2WorldGenerationPipeline(
|
|
|
48 |
|
49 |
def __init__(self, config=DiffusionText2WorldConfig()):
|
50 |
super().__init__(config)
|
51 |
+
torch.enable_grad(False)
|
52 |
self.config = config
|
53 |
inference_type = "text2world"
|
54 |
+
config.prompt = 1 # this is to hack args validation, maybe find a better way
|
55 |
validate_args(config, inference_type)
|
56 |
del config.prompt
|
57 |
self.pipeline = DiffusionText2WorldGenerationPipeline(
|