Finished rearranging model and config files
Browse files- .gitignore +1 -0
- model_config.py → ar_configs_model_config.py +0 -0
- cosmos1/models/autoregressive/diffusion_decoder/inference.py → ar_diffusion_decoder_inference.py +3 -3
- cosmos1/models/autoregressive/diffusion_decoder/model.py → ar_diffusion_decoder_model.py +5 -5
- cosmos1/models/autoregressive/diffusion_decoder/utils.py → ar_diffusion_decoder_utils.py +0 -0
- cosmos1/models/autoregressive/inference/world_generation_pipeline.py +3 -3
- cosmos1/models/autoregressive/nemo/utils.py +2 -2
- futureworld_hf.py +29 -16
- text2world_prompt_upsampler_inference.py +1 -1
- video2world_prompt_upsampler_inference.py +1 -1
- world_generation_pipeline.py +11 -10
.gitignore
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
huggingface.txt
|
model_config.py → ar_configs_model_config.py
RENAMED
File without changes
|
cosmos1/models/autoregressive/diffusion_decoder/inference.py → ar_diffusion_decoder_inference.py
RENAMED
@@ -19,9 +19,9 @@ from typing import List
|
|
19 |
|
20 |
import torch
|
21 |
|
22 |
-
from inference_config import DiffusionDecoderSamplingConfig
|
23 |
-
from
|
24 |
-
from
|
25 |
from .log import log
|
26 |
|
27 |
|
|
|
19 |
|
20 |
import torch
|
21 |
|
22 |
+
from .inference_config import DiffusionDecoderSamplingConfig
|
23 |
+
from .ar_diffusion_decoder_model import LatentDiffusionDecoderModel
|
24 |
+
from .ar_diffusion_decoder_utils import linear_blend_video_list, split_with_overlap
|
25 |
from .log import log
|
26 |
|
27 |
|
cosmos1/models/autoregressive/diffusion_decoder/model.py → ar_diffusion_decoder_model.py
RENAMED
@@ -19,11 +19,11 @@ from typing import Callable, Dict, Optional, Tuple
|
|
19 |
import torch
|
20 |
from torch import Tensor
|
21 |
|
22 |
-
from conditioner import BaseVideoCondition
|
23 |
-
from batch_ops import batch_mul
|
24 |
-
from res_sampler import COMMON_SOLVER_OPTIONS
|
25 |
-
from model_t2w import DiffusionT2WModel as VideoDiffusionModel
|
26 |
-
from lazy_config_init import instantiate as lazy_instantiate
|
27 |
|
28 |
|
29 |
@dataclass
|
|
|
19 |
import torch
|
20 |
from torch import Tensor
|
21 |
|
22 |
+
from .conditioner import BaseVideoCondition
|
23 |
+
from .batch_ops import batch_mul
|
24 |
+
from .res_sampler import COMMON_SOLVER_OPTIONS
|
25 |
+
from .model_t2w import DiffusionT2WModel as VideoDiffusionModel
|
26 |
+
from .lazy_config_init import instantiate as lazy_instantiate
|
27 |
|
28 |
|
29 |
@dataclass
|
cosmos1/models/autoregressive/diffusion_decoder/utils.py → ar_diffusion_decoder_utils.py
RENAMED
File without changes
|
cosmos1/models/autoregressive/inference/world_generation_pipeline.py
CHANGED
@@ -22,7 +22,7 @@ import numpy as np
|
|
22 |
import torch
|
23 |
from einops import rearrange
|
24 |
|
25 |
-
from
|
26 |
from ar_config_tokenizer import TokenizerConfig
|
27 |
from inference_config import (
|
28 |
DataShapeConfig,
|
@@ -30,8 +30,8 @@ from inference_config import (
|
|
30 |
InferenceConfig,
|
31 |
SamplingConfig,
|
32 |
)
|
33 |
-
from cosmos1.models.autoregressive.diffusion_decoder.
|
34 |
-
from cosmos1.models.autoregressive.diffusion_decoder.
|
35 |
from ar_model import AutoRegressiveModel
|
36 |
from cosmos1.models.autoregressive.utils.inference import _SUPPORTED_CONTEXT_LEN, prepare_video_batch_for_saving
|
37 |
from base_world_generation_pipeline import BaseWorldGenerationPipeline
|
|
|
22 |
import torch
|
23 |
from einops import rearrange
|
24 |
|
25 |
+
from ar_configs_model_config import create_video2world_model_config
|
26 |
from ar_config_tokenizer import TokenizerConfig
|
27 |
from inference_config import (
|
28 |
DataShapeConfig,
|
|
|
30 |
InferenceConfig,
|
31 |
SamplingConfig,
|
32 |
)
|
33 |
+
from cosmos1.models.autoregressive.diffusion_decoder.ar_diffusion_decoder_inference import diffusion_decoder_process_tokens
|
34 |
+
from cosmos1.models.autoregressive.diffusion_decoder.ar_diffusion_decoder_model import LatentDiffusionDecoderModel
|
35 |
from ar_model import AutoRegressiveModel
|
36 |
from cosmos1.models.autoregressive.utils.inference import _SUPPORTED_CONTEXT_LEN, prepare_video_batch_for_saving
|
37 |
from base_world_generation_pipeline import BaseWorldGenerationPipeline
|
cosmos1/models/autoregressive/nemo/utils.py
CHANGED
@@ -24,8 +24,8 @@ import torchvision
|
|
24 |
from huggingface_hub import snapshot_download
|
25 |
|
26 |
from inference_config import DiffusionDecoderSamplingConfig
|
27 |
-
from cosmos1.models.autoregressive.diffusion_decoder.
|
28 |
-
from cosmos1.models.autoregressive.diffusion_decoder.
|
29 |
from inference_utils import (
|
30 |
load_network_model,
|
31 |
load_tokenizer_model,
|
|
|
24 |
from huggingface_hub import snapshot_download
|
25 |
|
26 |
from inference_config import DiffusionDecoderSamplingConfig
|
27 |
+
from cosmos1.models.autoregressive.diffusion_decoder.ar_diffusion_decoder_inference import diffusion_decoder_process_tokens
|
28 |
+
from cosmos1.models.autoregressive.diffusion_decoder.ar_diffusion_decoder_model import LatentDiffusionDecoderModel
|
29 |
from inference_utils import (
|
30 |
load_network_model,
|
31 |
load_tokenizer_model,
|
futureworld_hf.py
CHANGED
@@ -19,15 +19,23 @@ class AutoregressiveFutureWorldConfig(PretrainedConfig):
|
|
19 |
def __init__(self, **kwargs):
|
20 |
super().__init__(**kwargs)
|
21 |
self.checkpoint_dir = kwargs.get("checkpoint_dir", "checkpoints")
|
22 |
-
self.
|
23 |
self.disable_diffusion_decoder = kwargs.get("disable_diffusion_decoder", False)
|
24 |
self.offload_guardrail_models = kwargs.get("offload_guardrail_models", False)
|
25 |
self.offload_diffusion_decoder = kwargs.get("offload_diffusion_decoder", False)
|
26 |
-
self.
|
27 |
self.offload_tokenizer = kwargs.get("offload_tokenizer", False)
|
28 |
self.video_save_name = kwargs.get("video_save_name", "output")
|
29 |
self.video_save_folder = kwargs.get("video_save_folder", "outputs/")
|
30 |
-
self.seed = kwargs.get()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
31 |
|
32 |
# custom model class
|
33 |
class AutoregressiveFutureWorld(PreTrainedModel):
|
@@ -37,17 +45,16 @@ class AutoregressiveFutureWorld(PreTrainedModel):
|
|
37 |
super().__init__(config)
|
38 |
torch._C._jit_set_texpr_fuser_enabled(False)
|
39 |
self.config = config
|
40 |
-
inference_type = "base"
|
41 |
-
sampling_config = validate_args(config, inference_type)
|
42 |
self.pipeline = ARBaseGenerationPipeline(
|
43 |
-
inference_type=inference_type,
|
44 |
-
checkpoint_dir=self.checkpoint_dir,
|
45 |
-
checkpoint_name=self.ar_model_dir,
|
46 |
-
disable_diffusion_decoder=self.disable_diffusion_decoder,
|
47 |
-
offload_guardrail_models=self.offload_guardrail_models,
|
48 |
-
offload_diffusion_decoder=self.offload_diffusion_decoder,
|
49 |
-
offload_network=self.offload_ar_model,
|
50 |
-
offload_tokenizer=self.offload_tokenizer,
|
51 |
)
|
52 |
|
53 |
# modifed from text2world.py demo function
|
@@ -63,6 +70,12 @@ class AutoregressiveFutureWorld(PreTrainedModel):
|
|
63 |
data_resolution=data_resolution,
|
64 |
num_input_frames=num_input_frames,
|
65 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
66 |
|
67 |
for idx, input_filename in enumerate(input_videos):
|
68 |
inp_vid = input_videos[input_filename]
|
@@ -71,7 +84,7 @@ class AutoregressiveFutureWorld(PreTrainedModel):
|
|
71 |
out_vid = self.pipeline.generate(
|
72 |
inp_vid=inp_vid,
|
73 |
num_input_frames=num_input_frames,
|
74 |
-
seed=
|
75 |
sampling_config=sampling_config,
|
76 |
)
|
77 |
if out_vid is None:
|
@@ -80,9 +93,9 @@ class AutoregressiveFutureWorld(PreTrainedModel):
|
|
80 |
|
81 |
# Save video
|
82 |
if input_image_or_video_path:
|
83 |
-
out_vid_path = os.path.join(
|
84 |
else:
|
85 |
-
out_vid_path = os.path.join(
|
86 |
|
87 |
imageio.mimsave(out_vid_path, out_vid, fps=25)
|
88 |
|
|
|
19 |
def __init__(self, **kwargs):
|
20 |
super().__init__(**kwargs)
|
21 |
self.checkpoint_dir = kwargs.get("checkpoint_dir", "checkpoints")
|
22 |
+
self.ar_model_dir = kwargs.get("ar_model_dir", "Cosmos-1.0-Autoregressive-4B")
|
23 |
self.disable_diffusion_decoder = kwargs.get("disable_diffusion_decoder", False)
|
24 |
self.offload_guardrail_models = kwargs.get("offload_guardrail_models", False)
|
25 |
self.offload_diffusion_decoder = kwargs.get("offload_diffusion_decoder", False)
|
26 |
+
self.offload_ar_model = kwargs.get("offload_ar_model", False)
|
27 |
self.offload_tokenizer = kwargs.get("offload_tokenizer", False)
|
28 |
self.video_save_name = kwargs.get("video_save_name", "output")
|
29 |
self.video_save_folder = kwargs.get("video_save_folder", "outputs/")
|
30 |
+
self.seed = kwargs.get("seed", 0)
|
31 |
+
self.temperature = kwargs.get("temperature", 1.0)
|
32 |
+
self.top_p = kwargs.get("top_p", 0.8)
|
33 |
+
self.input_type = None
|
34 |
+
self.batch_input_path = None
|
35 |
+
self.input_image_or_video_path = None
|
36 |
+
self.data_resolution = None
|
37 |
+
self.num_input_frames = None
|
38 |
+
|
39 |
|
40 |
# custom model class
|
41 |
class AutoregressiveFutureWorld(PreTrainedModel):
|
|
|
45 |
super().__init__(config)
|
46 |
torch._C._jit_set_texpr_fuser_enabled(False)
|
47 |
self.config = config
|
48 |
+
self.inference_type = "base"
|
|
|
49 |
self.pipeline = ARBaseGenerationPipeline(
|
50 |
+
inference_type=self.inference_type,
|
51 |
+
checkpoint_dir=self.config.checkpoint_dir,
|
52 |
+
checkpoint_name=self.config.ar_model_dir,
|
53 |
+
disable_diffusion_decoder=self.config.disable_diffusion_decoder,
|
54 |
+
offload_guardrail_models=self.config.offload_guardrail_models,
|
55 |
+
offload_diffusion_decoder=self.config.offload_diffusion_decoder,
|
56 |
+
offload_network=self.config.offload_ar_model,
|
57 |
+
offload_tokenizer=self.config.offload_tokenizer,
|
58 |
)
|
59 |
|
60 |
# modifed from text2world.py demo function
|
|
|
70 |
data_resolution=data_resolution,
|
71 |
num_input_frames=num_input_frames,
|
72 |
)
|
73 |
+
self.config.input_type = input_type
|
74 |
+
self.config.batch_input_path = batch_input_path
|
75 |
+
self.config.input_image_or_video_path = input_image_or_video_path
|
76 |
+
self.config.data_resolution = data_resolution
|
77 |
+
self.config.num_input_frames = num_input_frames
|
78 |
+
sampling_config = validate_args(self.config, self.inference_type)
|
79 |
|
80 |
for idx, input_filename in enumerate(input_videos):
|
81 |
inp_vid = input_videos[input_filename]
|
|
|
84 |
out_vid = self.pipeline.generate(
|
85 |
inp_vid=inp_vid,
|
86 |
num_input_frames=num_input_frames,
|
87 |
+
seed=self.config.seed,
|
88 |
sampling_config=sampling_config,
|
89 |
)
|
90 |
if out_vid is None:
|
|
|
93 |
|
94 |
# Save video
|
95 |
if input_image_or_video_path:
|
96 |
+
out_vid_path = os.path.join(self.config.video_save_folder, f"{self.config.video_save_name}.mp4")
|
97 |
else:
|
98 |
+
out_vid_path = os.path.join(self.config.video_save_folder, f"{idx}.mp4")
|
99 |
|
100 |
imageio.mimsave(out_vid_path, out_vid, fps=25)
|
101 |
|
text2world_prompt_upsampler_inference.py
CHANGED
@@ -23,7 +23,7 @@ import argparse
|
|
23 |
import os
|
24 |
import re
|
25 |
|
26 |
-
from .
|
27 |
from .ar_model import AutoRegressiveModel
|
28 |
from .inference import chat_completion
|
29 |
from .presets import presets as guardrail_presets
|
|
|
23 |
import os
|
24 |
import re
|
25 |
|
26 |
+
from .ar_configs_model_config import create_text_model_config
|
27 |
from .ar_model import AutoRegressiveModel
|
28 |
from .inference import chat_completion
|
29 |
from .presets import presets as guardrail_presets
|
video2world_prompt_upsampler_inference.py
CHANGED
@@ -26,7 +26,7 @@ from math import ceil
|
|
26 |
|
27 |
from PIL import Image
|
28 |
|
29 |
-
from .
|
30 |
from .ar_model import AutoRegressiveModel
|
31 |
from .inference import chat_completion
|
32 |
from .presets import presets as guardrail_presets
|
|
|
26 |
|
27 |
from PIL import Image
|
28 |
|
29 |
+
from .ar_configs_model_config import create_vision_language_model_config
|
30 |
from .ar_model import AutoRegressiveModel
|
31 |
from .inference import chat_completion
|
32 |
from .presets import presets as guardrail_presets
|
world_generation_pipeline.py
CHANGED
@@ -21,25 +21,26 @@ import numpy as np
|
|
21 |
import torch
|
22 |
from einops import rearrange
|
23 |
|
24 |
-
from
|
25 |
-
from
|
26 |
-
from
|
27 |
DataShapeConfig,
|
28 |
DiffusionDecoderSamplingConfig,
|
29 |
InferenceConfig,
|
30 |
SamplingConfig,
|
31 |
)
|
32 |
-
from
|
33 |
-
from
|
34 |
-
from
|
35 |
-
from
|
36 |
-
from
|
37 |
-
from
|
38 |
load_model_by_config,
|
39 |
load_network_model,
|
40 |
load_tokenizer_model,
|
41 |
)
|
42 |
-
from
|
|
|
43 |
|
44 |
|
45 |
def detect_model_size_from_ckpt_path(ckpt_path: str) -> str:
|
|
|
21 |
import torch
|
22 |
from einops import rearrange
|
23 |
|
24 |
+
from .ar_configs_model_config import create_video2world_model_config
|
25 |
+
from .ar_config_tokenizer import TokenizerConfig
|
26 |
+
from .ar_configs_inference import (
|
27 |
DataShapeConfig,
|
28 |
DiffusionDecoderSamplingConfig,
|
29 |
InferenceConfig,
|
30 |
SamplingConfig,
|
31 |
)
|
32 |
+
from .ar_diffusion_decoder_inference import diffusion_decoder_process_tokens
|
33 |
+
from .ar_diffusion_decoder_model import LatentDiffusionDecoderModel
|
34 |
+
from .ar_model import AutoRegressiveModel
|
35 |
+
from .ar_utils_inference import _SUPPORTED_CONTEXT_LEN, prepare_video_batch_for_saving
|
36 |
+
from .base_world_generation_pipeline import BaseWorldGenerationPipeline
|
37 |
+
from .inference_utils import (
|
38 |
load_model_by_config,
|
39 |
load_network_model,
|
40 |
load_tokenizer_model,
|
41 |
)
|
42 |
+
from .log import log
|
43 |
+
from .misc import misc
|
44 |
|
45 |
|
46 |
def detect_model_size_from_ckpt_path(ckpt_path: str) -> str:
|