modify misc
- aegis.py +4 -4
- ar_model.py +5 -5
- blocklist.py +3 -3
- face_blur_filter.py +2 -2
- inference_utils.py +3 -3
- misc.py +102 -96
- model_t2w.py +2 -2
- model_v2w.py +1 -1
- text2world.py +1 -1
- text2world_hf.py +1 -1
- video2world.py +1 -1
- video_content_safety_filter.py +2 -2
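
Across these files the commit replaces the module import `from . import misc` with `from .misc import misc, Color, timer`, then drops the `misc.` prefix at each call site. The diff uses `Color.green`/`Color.red` for log strings and `timer` both as a context manager (`with timer(...)`) and as a decorator (`@timer(...)` in model_t2w.py below), but their definitions are not part of this diff. A minimal hypothetical sketch consistent with that usage, assuming `termcolor` for coloring and `contextlib.ContextDecorator` so one object serves both roles:

import time
from contextlib import ContextDecorator

import termcolor


class Color:
    """Hypothetical sketch: the real Color lives in misc.py and is not shown here."""

    @staticmethod
    def green(text: str) -> str:
        return termcolor.colored(text, "green")

    @staticmethod
    def red(text: str) -> str:
        return termcolor.colored(text, "red")


class timer(ContextDecorator):
    """Hypothetical sketch: usable as `with timer("msg"):` or `@timer("msg")`."""

    def __init__(self, message: str):
        self.message = message

    def __enter__(self):
        self._start = time.perf_counter()
        return self

    def __exit__(self, *exc):
        print(f"{self.message}: {time.perf_counter() - self._start:.3f}s")
        return False
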
aegis.py
CHANGED
@@ -22,10 +22,10 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
 
 from .categories import UNSAFE_CATEGORIES
 from .guardrail_core import ContentSafetyGuardrail, GuardrailRunner
-from . import misc
+from .misc import misc, Color, timer
 
-SAFE = misc.Color.green("SAFE")
-UNSAFE = misc.Color.red("UNSAFE")
+SAFE = Color.green("SAFE")
+UNSAFE = Color.red("UNSAFE")
 
 DEFAULT_CHECKPOINT_DIR = "checkpoints/Cosmos-1.0-Guardrail/aegis"
 
@@ -120,7 +120,7 @@ def parse_args():
 def main(args):
     aegis = Aegis(checkpoint_dir=args.checkpoint_dir)
     runner = GuardrailRunner(safety_models=[aegis])
-    with misc.timer("aegis safety check"):
+    with timer("aegis safety check"):
         safety, message = runner.run_safety_check(args.prompt)
     log.info(f"Input is: {'SAFE' if safety else 'UNSAFE'}")
     log.info(f"Message: {message}") if not safety else None

ar_model.py
CHANGED
@@ -36,7 +36,7 @@ from .checkpoint import (
     substrings_to_ignore,
 )
 from .sampling import decode_n_tokens, decode_one_token, prefill
-from . import misc
+from .misc import misc, Color, timer
 
 
 class AutoRegressiveModel(torch.nn.Module):
@@ -96,7 +96,7 @@ class AutoRegressiveModel(torch.nn.Module):
         """
         model_config = self.config
         ckpt_path = model_config.ckpt_path
-        with misc.timer(f"loading checkpoint from {ckpt_path}"):
+        with timer(f"loading checkpoint from {ckpt_path}"):
             if ckpt_path.endswith("safetensors"):
                 # Load with safetensors API
                 checkpoint = load_file(ckpt_path, device="cpu")
@@ -142,7 +142,7 @@ class AutoRegressiveModel(torch.nn.Module):
         )
         # Remove the "model." prefix in the state_dict
         llm_checkpoint = process_state_dict(llm_checkpoint, prefix_to_remove="model.")
-        with misc.timer("loading state_dict into model"):
+        with timer("loading state_dict into model"):
            missing_keys, _ = model.load_state_dict(llm_checkpoint, strict=True)
        # Remove keys with "_extra_state" suffix in missing_keys (defined by TransformerEngine for FP8 usage)
        missing_keys = [k for k in missing_keys if not k.endswith("_extra_state")]
@@ -217,7 +217,7 @@ class AutoRegressiveModel(torch.nn.Module):
             # Override the default model configuration with the parameters from the checkpoint
             setattr(model_config, key, value)
 
-        with misc.timer(f"loading checkpoint from {ckpt_path}"):
+        with timer(f"loading checkpoint from {ckpt_path}"):
             if ckpt_path.endswith("safetensors"):
                 # Load with safetensors API
                 checkpoint = load_file(ckpt_path, device="cpu")
@@ -293,7 +293,7 @@ class AutoRegressiveModel(torch.nn.Module):
 
         # Remove the "model." prefix in the state_dict
         llm_checkpoint = process_state_dict(llm_checkpoint, prefix_to_remove="model.")
-        with misc.timer("loading state_dict into model"):
+        with timer("loading state_dict into model"):
            missing_keys, unexpected_keys = model.load_state_dict(llm_checkpoint, strict=True)
        # Remove keys with "_extra_state" suffix in missing_keys (defined by TransformerEngine for FP8 usage)
        missing_keys = [k for k in missing_keys if not k.endswith("_extra_state")]

blocklist.py
CHANGED
@@ -25,10 +25,10 @@ from better_profanity import profanity
 
 from .guardrail_blocklist_utils import read_keyword_list_from_dir, to_ascii
 from .guardrail_core import ContentSafetyGuardrail, GuardrailRunner
-from . import misc
+from .misc import misc, Color, timer
 
 DEFAULT_CHECKPOINT_DIR = "checkpoints/Cosmos-1.0-Guardrail/blocklist"
-CENSOR = misc.Color.red("*")
+CENSOR = Color.red("*")
 
 
 class Blocklist(ContentSafetyGuardrail):
@@ -208,7 +208,7 @@ def parse_args():
 def main(args):
     blocklist = Blocklist(checkpoint_dir=args.checkpoint_dir)
     runner = GuardrailRunner(safety_models=[blocklist])
-    with misc.timer("blocklist safety check"):
+    with timer("blocklist safety check"):
         safety, message = runner.run_safety_check(args.prompt)
     log.info(f"Input is: {'SAFE' if safety else 'UNSAFE'}")
     log.info(f"Message: {message}") if not safety else None

face_blur_filter.py
CHANGED
@@ -29,7 +29,7 @@ from .guardrail_core import GuardrailRunner, PostprocessingGuardrail
 from .guardrail_io_utils import get_video_filepaths, read_video, save_video
 from .blur_utils import pixelate_face
 from .retinaface_utils import decode_batch, filter_detected_boxes, load_model
-from . import misc
+from .misc import misc, Color, timer
 
 DEFAULT_RETINAFACE_CHECKPOINT = "checkpoints/Cosmos-1.0-Guardrail/face_blur_filter/Resnet50_Final.pth"
 
@@ -212,7 +212,7 @@ def main(args):
 
     for filepath in tqdm(filepaths):
         video_data = read_video(filepath)
-        with misc.timer("face blur filter"):
+        with timer("face blur filter"):
             frames = postprocessing_runner.postprocess(video_data.frames)
 
         output_path = os.path.join(args.output_dir, os.path.basename(filepath))

inference_utils.py
CHANGED
@@ -28,7 +28,7 @@ from .model_t2w import DiffusionT2WModel
 from .model_v2w import DiffusionV2WModel
 from .config_helper import get_config_module, override
 from .utils_io import load_from_fileobj
-from .misc import arch_invariant_rand
+from .misc import misc
 
 TORCH_VERSION: Tuple[int, ...] = tuple(int(x) for x in torch.__version__.split(".")[:2])
 if TORCH_VERSION >= (1, 11):
@@ -418,7 +418,7 @@ def generate_world_from_text(
     3. Decodes latents to pixel space
     """
     x_sigma_max = (
-        arch_invariant_rand(
+        misc.arch_invariant_rand(
             (1,) + tuple(state_shape),
             torch.float32,
             model.tensor_kwargs["device"],
@@ -484,7 +484,7 @@ def generate_world_from_video(
     num_of_latent_condition = compute_num_latent_frames(model, num_input_frames)
 
     x_sigma_max = (
-        arch_invariant_rand(
+        misc.arch_invariant_rand(
             (1,) + tuple(state_shape),
             torch.float32,
             model.tensor_kwargs["device"],

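
Both call sites above draw the initial noise `x_sigma_max` through `misc.arch_invariant_rand` rather than `torch.randn`. As the misc.py hunk below shows, the noise is generated by NumPy's `RandomState` on the CPU and only then moved to the target device, so a fixed seed reproduces the same tensor on any GPU architecture. A standalone sketch of that property (the function body is copied from the hunk below; the assertion is illustrative):

import numpy as np
import torch


def arch_invariant_rand(shape, dtype, device, seed=None):
    # CPU-side NumPy generator: output is independent of the GPU architecture
    rng = np.random.RandomState(seed)
    random_array = rng.standard_normal(shape).astype(np.float32)
    return torch.from_numpy(random_array).to(dtype=dtype, device=device)


a = arch_invariant_rand((1, 4), torch.float32, "cpu", seed=0)
b = arch_invariant_rand((1, 4), torch.float32, "cpu", seed=0)
assert torch.equal(a, b)  # same seed, same noise, on every device
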
misc.py
CHANGED
@@ -29,109 +29,115 @@ import numpy as np
 import termcolor
 import torch
 
-from . import distributed
+from .distributed import get_rank
 
 
-def to(
-    data: Any,
-    device: str | torch.device | None = None,
-    dtype: torch.dtype | None = None,
-    memory_format: torch.memory_format = torch.preserve_format,
-) -> Any:
-    """Recursively cast data into the specified device, dtype, and/or memory_format.
-
-    The input data can be a tensor, a list of tensors, a dict of tensors.
-    See the documentation for torch.Tensor.to() for details.
-
-    Args:
-        data (Any): Input data.
-        device (str | torch.device): GPU device (default: None).
-        dtype (torch.dtype): data type (default: None).
-        memory_format (torch.memory_format): memory organization format (default: torch.preserve_format).
-
-    Returns:
-        data (Any): Data cast to the specified device, dtype, and/or memory_format.
-    """
-    assert (
-        device is not None or dtype is not None or memory_format is not None
-    ), "at least one of device, dtype, memory_format should be specified"
-    if isinstance(data, torch.Tensor):
-        is_cpu = (isinstance(device, str) and device == "cpu") or (
-            isinstance(device, torch.device) and device.type == "cpu"
-        )
-        data = data.to(
-            device=device,
-            dtype=dtype,
-            memory_format=memory_format,
-            non_blocking=(not is_cpu),
-        )
-        return data
-    elif isinstance(data, collections.abc.Mapping):
-        return type(data)({key: to(data[key], device=device, dtype=dtype, memory_format=memory_format) for key in data})
-    elif isinstance(data, collections.abc.Sequence) and not isinstance(data, (str, bytes)):
-        return type(data)([to(elem, device=device, dtype=dtype, memory_format=memory_format) for elem in data])
-    else:
-        return data
-
-
-def serialize(data: Any) -> Any:
-    """Serialize data by hierarchically traversing through iterables.
-
-    Args:
-        data (Any): Input data.
-
-    Returns:
-        data (Any): Serialized data.
-    """
-    if isinstance(data, collections.abc.Mapping):
-        return type(data)({key: serialize(data[key]) for key in data})
-    elif isinstance(data, collections.abc.Sequence) and not isinstance(data, (str, bytes)):
-        return type(data)([serialize(elem) for elem in data])
-    else:
-        try:
-            json.dumps(data)
-        except TypeError:
-            data = str(data)
-        return data
-
-
-def set_random_seed(seed: int, by_rank: bool = False) -> None:
-    """Set random seed. This includes random, numpy, Pytorch.
-
-    Args:
-        seed (int): Random seed.
-        by_rank (bool): if true, each GPU will use a different random seed.
-    """
-    if by_rank:
-        seed += distributed.get_rank()
-    log.info(f"Using random seed {seed}.")
-    random.seed(seed)
-    np.random.seed(seed)
-    torch.manual_seed(seed)  # sets seed on the current CPU & all GPUs
-
-
-def arch_invariant_rand(
-    shape: List[int] | Tuple[int], dtype: torch.dtype, device: str | torch.device, seed: int | None = None
-):
-    """Produce a GPU-architecture-invariant randomized Torch tensor.
-
-    Args:
-        shape (list or tuple of ints): Output tensor shape.
-        dtype (torch.dtype): Output tensor type.
-        device (torch.device): Device holding the output.
-        seed (int): Optional randomization seed.
-
-    Returns:
-        tensor (torch.tensor): Randomly-generated tensor.
-    """
-    # Create a random number generator, optionally seeded
-    rng = np.random.RandomState(seed)
-
-    # Generate random numbers using the generator
-    random_array = rng.standard_normal(shape).astype(np.float32)  # Use standard_normal for normal distribution
-
-    # Convert to torch tensor and return
-    return torch.from_numpy(random_array).to(dtype=dtype, device=device)
+class misc():
+
+    @staticmethod
+    def to(
+        data: Any,
+        device: str | torch.device | None = None,
+        dtype: torch.dtype | None = None,
+        memory_format: torch.memory_format = torch.preserve_format,
+    ) -> Any:
+        """Recursively cast data into the specified device, dtype, and/or memory_format.
+
+        The input data can be a tensor, a list of tensors, a dict of tensors.
+        See the documentation for torch.Tensor.to() for details.
+
+        Args:
+            data (Any): Input data.
+            device (str | torch.device): GPU device (default: None).
+            dtype (torch.dtype): data type (default: None).
+            memory_format (torch.memory_format): memory organization format (default: torch.preserve_format).
+
+        Returns:
+            data (Any): Data cast to the specified device, dtype, and/or memory_format.
+        """
+        assert (
+            device is not None or dtype is not None or memory_format is not None
+        ), "at least one of device, dtype, memory_format should be specified"
+        if isinstance(data, torch.Tensor):
+            is_cpu = (isinstance(device, str) and device == "cpu") or (
+                isinstance(device, torch.device) and device.type == "cpu"
+            )
+            data = data.to(
+                device=device,
+                dtype=dtype,
+                memory_format=memory_format,
+                non_blocking=(not is_cpu),
+            )
+            return data
+        elif isinstance(data, collections.abc.Mapping):
+            return type(data)({key: to(data[key], device=device, dtype=dtype, memory_format=memory_format) for key in data})
+        elif isinstance(data, collections.abc.Sequence) and not isinstance(data, (str, bytes)):
+            return type(data)([to(elem, device=device, dtype=dtype, memory_format=memory_format) for elem in data])
+        else:
+            return data
+
+    @staticmethod
+    def serialize(data: Any) -> Any:
+        """Serialize data by hierarchically traversing through iterables.
+
+        Args:
+            data (Any): Input data.
+
+        Returns:
+            data (Any): Serialized data.
+        """
+        if isinstance(data, collections.abc.Mapping):
+            return type(data)({key: serialize(data[key]) for key in data})
+        elif isinstance(data, collections.abc.Sequence) and not isinstance(data, (str, bytes)):
+            return type(data)([serialize(elem) for elem in data])
+        else:
+            try:
+                json.dumps(data)
+            except TypeError:
+                data = str(data)
+            return data
+
+    @staticmethod
+    def set_random_seed(seed: int, by_rank: bool = False) -> None:
+        """Set random seed. This includes random, numpy, Pytorch.
+
+        Args:
+            seed (int): Random seed.
+            by_rank (bool): if true, each GPU will use a different random seed.
+        """
+        if by_rank:
+            seed += get_rank()
+        log.info(f"Using random seed {seed}.")
+        random.seed(seed)
+        np.random.seed(seed)
+        torch.manual_seed(seed)  # sets seed on the current CPU & all GPUs
+
+    @staticmethod
+    def arch_invariant_rand(
+        shape: List[int] | Tuple[int], dtype: torch.dtype, device: str | torch.device, seed: int | None = None
+    ):
+        """Produce a GPU-architecture-invariant randomized Torch tensor.
+
+        Args:
+            shape (list or tuple of ints): Output tensor shape.
+            dtype (torch.dtype): Output tensor type.
+            device (torch.device): Device holding the output.
+            seed (int): Optional randomization seed.
+
+        Returns:
+            tensor (torch.tensor): Randomly-generated tensor.
+        """
+        # Create a random number generator, optionally seeded
+        rng = np.random.RandomState(seed)
+
+        # Generate random numbers using the generator
+        random_array = rng.standard_normal(shape).astype(np.float32)  # Use standard_normal for normal distribution
+
+        # Convert to torch tensor and return
+        return torch.from_numpy(random_array).to(dtype=dtype, device=device)
 
 
 T = TypeVar("T", bound=Callable[..., Any])

model_t2w.py
CHANGED
@@ -25,7 +25,7 @@ from .res_sampler import COMMON_SOLVER_OPTIONS, Sampler
 from .diffusion_types import DenoisePrediction
 from .blocks import FourierFeatures
 from .pretrained_vae import BaseVAE
-from . import misc
+from .misc import misc, Color, timer
 from . import instantiate as lazy_instantiate
 from .log import log
 
@@ -96,7 +96,7 @@ class DiffusionT2WModel(torch.nn.Module):
         if hasattr(self.tokenizer, "reset_dtype"):
             self.tokenizer.reset_dtype()
 
-    @misc.timer("DiffusionModel: set_up_model")
+    @timer("DiffusionModel: set_up_model")
     def set_up_model(self, memory_format: torch.memory_format = torch.preserve_format):
         """Initialize the core model components including network, conditioner and logvar."""
         self.model = self.build_model()

model_v2w.py
CHANGED
@@ -24,7 +24,7 @@ from .conditioner import VideoExtendCondition
 from .config_base_conditioner import VideoCondBoolConfig
 from .batch_ops import batch_mul
 from .model_t2w import DiffusionT2WModel
-from . import misc
+from .misc import misc, Color, timer
 
 
 @dataclass

text2world.py
CHANGED
@@ -21,7 +21,7 @@ import torch
 
 from .inference_utils import add_common_arguments, validate_args
 from .world_generation_pipeline import DiffusionText2WorldGenerationPipeline
-from . import misc
+from .misc import misc, Color, timer
 from .utils_io import read_prompts_from_file, save_video
 
 torch.enable_grad(False)

text2world_hf.py
CHANGED
@@ -6,7 +6,7 @@ from transformers import PreTrainedModel, PretrainedConfig
 from .inference_utils import add_common_arguments, validate_args
 from .world_generation_pipeline import DiffusionText2WorldGenerationPipeline
 from .log import log
-from . import misc
+from .misc import misc, Color, timer
 from .utils_io import read_prompts_from_file, save_video
 
 

video2world.py
CHANGED
@@ -21,7 +21,7 @@ import torch
 
 from .inference_utils import add_common_arguments, check_input_frames, validate_args
 from .world_generation_pipeline import DiffusionVideo2WorldGenerationPipeline
-from . import misc
+from .misc import misc, Color, timer
 from .utils_io import read_prompts_from_file, save_video
 
 torch.enable_grad(False)

video_content_safety_filter.py
CHANGED
@@ -26,7 +26,7 @@ from .guardrail_core import ContentSafetyGuardrail, GuardrailRunner
 from .guardrail_io_utils import get_video_filepaths, read_video
 from .video_content_safety_filter_model import ModelConfig, VideoSafetyModel
 from .video_content_safety_filter_vision_encoder import SigLIPEncoder
-from . import misc
+from .misc import misc, Color, timer
 
 DEFAULT_CHECKPOINT_DIR = "checkpoints/Cosmos-1.0-Guardrail/video_content_safety_filter"
 
@@ -178,7 +178,7 @@ def main(args):
     runner = GuardrailRunner(safety_models=[video_filter], generic_safe_msg="Video is safe")
 
     for filepath in filepaths:
-        with misc.timer("video content safety filter"):
+        with timer("video content safety filter"):
             _ = runner.run_safety_check(filepath)