EthanZyh commited on
Commit
62c1236
·
1 Parent(s): 02c5b0e

modify misc

Browse files
aegis.py CHANGED
@@ -22,10 +22,10 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
22
 
23
  from .categories import UNSAFE_CATEGORIES
24
  from .guardrail_core import ContentSafetyGuardrail, GuardrailRunner
25
- from . import misc
26
 
27
- SAFE = misc.Color.green("SAFE")
28
- UNSAFE = misc.Color.red("UNSAFE")
29
 
30
  DEFAULT_CHECKPOINT_DIR = "checkpoints/Cosmos-1.0-Guardrail/aegis"
31
 
@@ -120,7 +120,7 @@ def parse_args():
120
  def main(args):
121
  aegis = Aegis(checkpoint_dir=args.checkpoint_dir)
122
  runner = GuardrailRunner(safety_models=[aegis])
123
- with misc.timer("aegis safety check"):
124
  safety, message = runner.run_safety_check(args.prompt)
125
  log.info(f"Input is: {'SAFE' if safety else 'UNSAFE'}")
126
  log.info(f"Message: {message}") if not safety else None
 
22
 
23
  from .categories import UNSAFE_CATEGORIES
24
  from .guardrail_core import ContentSafetyGuardrail, GuardrailRunner
25
+ from .misc import misc, Color, timer
26
 
27
+ SAFE = Color.green("SAFE")
28
+ UNSAFE = Color.red("UNSAFE")
29
 
30
  DEFAULT_CHECKPOINT_DIR = "checkpoints/Cosmos-1.0-Guardrail/aegis"
31
 
 
120
  def main(args):
121
  aegis = Aegis(checkpoint_dir=args.checkpoint_dir)
122
  runner = GuardrailRunner(safety_models=[aegis])
123
+ with timer("aegis safety check"):
124
  safety, message = runner.run_safety_check(args.prompt)
125
  log.info(f"Input is: {'SAFE' if safety else 'UNSAFE'}")
126
  log.info(f"Message: {message}") if not safety else None
ar_model.py CHANGED
@@ -36,7 +36,7 @@ from .checkpoint import (
36
  substrings_to_ignore,
37
  )
38
  from .sampling import decode_n_tokens, decode_one_token, prefill
39
- from . import misc
40
 
41
 
42
  class AutoRegressiveModel(torch.nn.Module):
@@ -96,7 +96,7 @@ class AutoRegressiveModel(torch.nn.Module):
96
  """
97
  model_config = self.config
98
  ckpt_path = model_config.ckpt_path
99
- with misc.timer(f"loading checkpoint from {ckpt_path}"):
100
  if ckpt_path.endswith("safetensors"):
101
  # Load with safetensors API
102
  checkpoint = load_file(ckpt_path, device="cpu")
@@ -142,7 +142,7 @@ class AutoRegressiveModel(torch.nn.Module):
142
  )
143
  # Remove the "model." prefix in the state_dict
144
  llm_checkpoint = process_state_dict(llm_checkpoint, prefix_to_remove="model.")
145
- with misc.timer("loading state_dict into model"):
146
  missing_keys, _ = model.load_state_dict(llm_checkpoint, strict=True)
147
  # Remove keys with "_extra_state" suffix in missing_keys (defined by TransformerEngine for FP8 usage)
148
  missing_keys = [k for k in missing_keys if not k.endswith("_extra_state")]
@@ -217,7 +217,7 @@ class AutoRegressiveModel(torch.nn.Module):
217
  # Override the default model configuration with the parameters from the checkpoint
218
  setattr(model_config, key, value)
219
 
220
- with misc.timer(f"loading checkpoint from {ckpt_path}"):
221
  if ckpt_path.endswith("safetensors"):
222
  # Load with safetensors API
223
  checkpoint = load_file(ckpt_path, device="cpu")
@@ -293,7 +293,7 @@ class AutoRegressiveModel(torch.nn.Module):
293
 
294
  # Remove the "model." prefix in the state_dict
295
  llm_checkpoint = process_state_dict(llm_checkpoint, prefix_to_remove="model.")
296
- with misc.timer("loading state_dict into model"):
297
  missing_keys, unexpected_keys = model.load_state_dict(llm_checkpoint, strict=True)
298
  # Remove keys with "_extra_state" suffix in missing_keys (defined by TransformerEngine for FP8 usage)
299
  missing_keys = [k for k in missing_keys if not k.endswith("_extra_state")]
 
36
  substrings_to_ignore,
37
  )
38
  from .sampling import decode_n_tokens, decode_one_token, prefill
39
+ from .misc import misc, Color, timer
40
 
41
 
42
  class AutoRegressiveModel(torch.nn.Module):
 
96
  """
97
  model_config = self.config
98
  ckpt_path = model_config.ckpt_path
99
+ with timer(f"loading checkpoint from {ckpt_path}"):
100
  if ckpt_path.endswith("safetensors"):
101
  # Load with safetensors API
102
  checkpoint = load_file(ckpt_path, device="cpu")
 
142
  )
143
  # Remove the "model." prefix in the state_dict
144
  llm_checkpoint = process_state_dict(llm_checkpoint, prefix_to_remove="model.")
145
+ with timer("loading state_dict into model"):
146
  missing_keys, _ = model.load_state_dict(llm_checkpoint, strict=True)
147
  # Remove keys with "_extra_state" suffix in missing_keys (defined by TransformerEngine for FP8 usage)
148
  missing_keys = [k for k in missing_keys if not k.endswith("_extra_state")]
 
217
  # Override the default model configuration with the parameters from the checkpoint
218
  setattr(model_config, key, value)
219
 
220
+ with timer(f"loading checkpoint from {ckpt_path}"):
221
  if ckpt_path.endswith("safetensors"):
222
  # Load with safetensors API
223
  checkpoint = load_file(ckpt_path, device="cpu")
 
293
 
294
  # Remove the "model." prefix in the state_dict
295
  llm_checkpoint = process_state_dict(llm_checkpoint, prefix_to_remove="model.")
296
+ with timer("loading state_dict into model"):
297
  missing_keys, unexpected_keys = model.load_state_dict(llm_checkpoint, strict=True)
298
  # Remove keys with "_extra_state" suffix in missing_keys (defined by TransformerEngine for FP8 usage)
299
  missing_keys = [k for k in missing_keys if not k.endswith("_extra_state")]
blocklist.py CHANGED
@@ -25,10 +25,10 @@ from better_profanity import profanity
25
 
26
  from .guardrail_blocklist_utils import read_keyword_list_from_dir, to_ascii
27
  from .guardrail_core import ContentSafetyGuardrail, GuardrailRunner
28
- from . import misc
29
 
30
  DEFAULT_CHECKPOINT_DIR = "checkpoints/Cosmos-1.0-Guardrail/blocklist"
31
- CENSOR = misc.Color.red("*")
32
 
33
 
34
  class Blocklist(ContentSafetyGuardrail):
@@ -208,7 +208,7 @@ def parse_args():
208
  def main(args):
209
  blocklist = Blocklist(checkpoint_dir=args.checkpoint_dir)
210
  runner = GuardrailRunner(safety_models=[blocklist])
211
- with misc.timer("blocklist safety check"):
212
  safety, message = runner.run_safety_check(args.prompt)
213
  log.info(f"Input is: {'SAFE' if safety else 'UNSAFE'}")
214
  log.info(f"Message: {message}") if not safety else None
 
25
 
26
  from .guardrail_blocklist_utils import read_keyword_list_from_dir, to_ascii
27
  from .guardrail_core import ContentSafetyGuardrail, GuardrailRunner
28
+ from .misc import misc, Color, timer
29
 
30
  DEFAULT_CHECKPOINT_DIR = "checkpoints/Cosmos-1.0-Guardrail/blocklist"
31
+ CENSOR = Color.red("*")
32
 
33
 
34
  class Blocklist(ContentSafetyGuardrail):
 
208
  def main(args):
209
  blocklist = Blocklist(checkpoint_dir=args.checkpoint_dir)
210
  runner = GuardrailRunner(safety_models=[blocklist])
211
+ with timer("blocklist safety check"):
212
  safety, message = runner.run_safety_check(args.prompt)
213
  log.info(f"Input is: {'SAFE' if safety else 'UNSAFE'}")
214
  log.info(f"Message: {message}") if not safety else None
face_blur_filter.py CHANGED
@@ -29,7 +29,7 @@ from .guardrail_core import GuardrailRunner, PostprocessingGuardrail
29
  from .guardrail_io_utils import get_video_filepaths, read_video, save_video
30
  from .blur_utils import pixelate_face
31
  from .retinaface_utils import decode_batch, filter_detected_boxes, load_model
32
- from . import misc
33
 
34
  DEFAULT_RETINAFACE_CHECKPOINT = "checkpoints/Cosmos-1.0-Guardrail/face_blur_filter/Resnet50_Final.pth"
35
 
@@ -212,7 +212,7 @@ def main(args):
212
 
213
  for filepath in tqdm(filepaths):
214
  video_data = read_video(filepath)
215
- with misc.timer("face blur filter"):
216
  frames = postprocessing_runner.postprocess(video_data.frames)
217
 
218
  output_path = os.path.join(args.output_dir, os.path.basename(filepath))
 
29
  from .guardrail_io_utils import get_video_filepaths, read_video, save_video
30
  from .blur_utils import pixelate_face
31
  from .retinaface_utils import decode_batch, filter_detected_boxes, load_model
32
+ from .misc import misc, Color, timer
33
 
34
  DEFAULT_RETINAFACE_CHECKPOINT = "checkpoints/Cosmos-1.0-Guardrail/face_blur_filter/Resnet50_Final.pth"
35
 
 
212
 
213
  for filepath in tqdm(filepaths):
214
  video_data = read_video(filepath)
215
+ with timer("face blur filter"):
216
  frames = postprocessing_runner.postprocess(video_data.frames)
217
 
218
  output_path = os.path.join(args.output_dir, os.path.basename(filepath))
inference_utils.py CHANGED
@@ -28,7 +28,7 @@ from .model_t2w import DiffusionT2WModel
28
  from .model_v2w import DiffusionV2WModel
29
  from .config_helper import get_config_module, override
30
  from .utils_io import load_from_fileobj
31
- from .misc import arch_invariant_rand
32
 
33
  TORCH_VERSION: Tuple[int, ...] = tuple(int(x) for x in torch.__version__.split(".")[:2])
34
  if TORCH_VERSION >= (1, 11):
@@ -418,7 +418,7 @@ def generate_world_from_text(
418
  3. Decodes latents to pixel space
419
  """
420
  x_sigma_max = (
421
- arch_invariant_rand(
422
  (1,) + tuple(state_shape),
423
  torch.float32,
424
  model.tensor_kwargs["device"],
@@ -484,7 +484,7 @@ def generate_world_from_video(
484
  num_of_latent_condition = compute_num_latent_frames(model, num_input_frames)
485
 
486
  x_sigma_max = (
487
- arch_invariant_rand(
488
  (1,) + tuple(state_shape),
489
  torch.float32,
490
  model.tensor_kwargs["device"],
 
28
  from .model_v2w import DiffusionV2WModel
29
  from .config_helper import get_config_module, override
30
  from .utils_io import load_from_fileobj
31
+ from .misc import misc
32
 
33
  TORCH_VERSION: Tuple[int, ...] = tuple(int(x) for x in torch.__version__.split(".")[:2])
34
  if TORCH_VERSION >= (1, 11):
 
418
  3. Decodes latents to pixel space
419
  """
420
  x_sigma_max = (
421
+ misc.arch_invariant_rand(
422
  (1,) + tuple(state_shape),
423
  torch.float32,
424
  model.tensor_kwargs["device"],
 
484
  num_of_latent_condition = compute_num_latent_frames(model, num_input_frames)
485
 
486
  x_sigma_max = (
487
+ misc.arch_invariant_rand(
488
  (1,) + tuple(state_shape),
489
  torch.float32,
490
  model.tensor_kwargs["device"],
misc.py CHANGED
@@ -29,109 +29,115 @@ import numpy as np
29
  import termcolor
30
  import torch
31
 
32
- from . import distributed
33
 
34
 
35
- def to(
36
- data: Any,
37
- device: str | torch.device | None = None,
38
- dtype: torch.dtype | None = None,
39
- memory_format: torch.memory_format = torch.preserve_format,
40
- ) -> Any:
41
- """Recursively cast data into the specified device, dtype, and/or memory_format.
42
 
43
- The input data can be a tensor, a list of tensors, a dict of tensors.
44
- See the documentation for torch.Tensor.to() for details.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
- Args:
47
- data (Any): Input data.
48
- device (str | torch.device): GPU device (default: None).
49
- dtype (torch.dtype): data type (default: None).
50
- memory_format (torch.memory_format): memory organization format (default: torch.preserve_format).
51
 
52
- Returns:
53
- data (Any): Data cast to the specified device, dtype, and/or memory_format.
54
- """
55
- assert (
56
- device is not None or dtype is not None or memory_format is not None
57
- ), "at least one of device, dtype, memory_format should be specified"
58
- if isinstance(data, torch.Tensor):
59
- is_cpu = (isinstance(device, str) and device == "cpu") or (
60
- isinstance(device, torch.device) and device.type == "cpu"
61
- )
62
- data = data.to(
63
- device=device,
64
- dtype=dtype,
65
- memory_format=memory_format,
66
- non_blocking=(not is_cpu),
67
- )
68
- return data
69
- elif isinstance(data, collections.abc.Mapping):
70
- return type(data)({key: to(data[key], device=device, dtype=dtype, memory_format=memory_format) for key in data})
71
- elif isinstance(data, collections.abc.Sequence) and not isinstance(data, (str, bytes)):
72
- return type(data)([to(elem, device=device, dtype=dtype, memory_format=memory_format) for elem in data])
73
- else:
74
- return data
75
-
76
-
77
- def serialize(data: Any) -> Any:
78
- """Serialize data by hierarchically traversing through iterables.
79
-
80
- Args:
81
- data (Any): Input data.
82
-
83
- Returns:
84
- data (Any): Serialized data.
85
- """
86
- if isinstance(data, collections.abc.Mapping):
87
- return type(data)({key: serialize(data[key]) for key in data})
88
- elif isinstance(data, collections.abc.Sequence) and not isinstance(data, (str, bytes)):
89
- return type(data)([serialize(elem) for elem in data])
90
- else:
91
- try:
92
- json.dumps(data)
93
- except TypeError:
94
- data = str(data)
95
- return data
96
-
97
-
98
- def set_random_seed(seed: int, by_rank: bool = False) -> None:
99
- """Set random seed. This includes random, numpy, Pytorch.
100
-
101
- Args:
102
- seed (int): Random seed.
103
- by_rank (bool): if true, each GPU will use a different random seed.
104
- """
105
- if by_rank:
106
- seed += distributed.get_rank()
107
- log.info(f"Using random seed {seed}.")
108
- random.seed(seed)
109
- np.random.seed(seed)
110
- torch.manual_seed(seed) # sets seed on the current CPU & all GPUs
111
-
112
-
113
- def arch_invariant_rand(
114
- shape: List[int] | Tuple[int], dtype: torch.dtype, device: str | torch.device, seed: int | None = None
115
- ):
116
- """Produce a GPU-architecture-invariant randomized Torch tensor.
117
-
118
- Args:
119
- shape (list or tuple of ints): Output tensor shape.
120
- dtype (torch.dtype): Output tensor type.
121
- device (torch.device): Device holding the output.
122
- seed (int): Optional randomization seed.
123
-
124
- Returns:
125
- tensor (torch.tensor): Randomly-generated tensor.
126
- """
127
- # Create a random number generator, optionally seeded
128
- rng = np.random.RandomState(seed)
129
 
130
- # # Generate random numbers using the generator
131
- random_array = rng.standard_normal(shape).astype(np.float32) # Use standard_normal for normal distribution
132
 
133
- # Convert to torch tensor and return
134
- return torch.from_numpy(random_array).to(dtype=dtype, device=device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
 
136
 
137
  T = TypeVar("T", bound=Callable[..., Any])
 
29
  import termcolor
30
  import torch
31
 
32
+ from .distributed import get_rank
33
 
34
 
35
+ class misc():
 
 
 
 
 
 
36
 
37
+ @staticmethod
38
+ def to(
39
+ data: Any,
40
+ device: str | torch.device | None = None,
41
+ dtype: torch.dtype | None = None,
42
+ memory_format: torch.memory_format = torch.preserve_format,
43
+ ) -> Any:
44
+ """Recursively cast data into the specified device, dtype, and/or memory_format.
45
+
46
+ The input data can be a tensor, a list of tensors, a dict of tensors.
47
+ See the documentation for torch.Tensor.to() for details.
48
+
49
+ Args:
50
+ data (Any): Input data.
51
+ device (str | torch.device): GPU device (default: None).
52
+ dtype (torch.dtype): data type (default: None).
53
+ memory_format (torch.memory_format): memory organization format (default: torch.preserve_format).
54
+
55
+ Returns:
56
+ data (Any): Data cast to the specified device, dtype, and/or memory_format.
57
+ """
58
+ assert (
59
+ device is not None or dtype is not None or memory_format is not None
60
+ ), "at least one of device, dtype, memory_format should be specified"
61
+ if isinstance(data, torch.Tensor):
62
+ is_cpu = (isinstance(device, str) and device == "cpu") or (
63
+ isinstance(device, torch.device) and device.type == "cpu"
64
+ )
65
+ data = data.to(
66
+ device=device,
67
+ dtype=dtype,
68
+ memory_format=memory_format,
69
+ non_blocking=(not is_cpu),
70
+ )
71
+ return data
72
+ elif isinstance(data, collections.abc.Mapping):
73
+ return type(data)({key: to(data[key], device=device, dtype=dtype, memory_format=memory_format) for key in data})
74
+ elif isinstance(data, collections.abc.Sequence) and not isinstance(data, (str, bytes)):
75
+ return type(data)([to(elem, device=device, dtype=dtype, memory_format=memory_format) for elem in data])
76
+ else:
77
+ return data
78
 
 
 
 
 
 
79
 
80
+ @staticmethod
81
+ def serialize(data: Any) -> Any:
82
+ """Serialize data by hierarchically traversing through iterables.
83
+
84
+ Args:
85
+ data (Any): Input data.
86
+
87
+ Returns:
88
+ data (Any): Serialized data.
89
+ """
90
+ if isinstance(data, collections.abc.Mapping):
91
+ return type(data)({key: serialize(data[key]) for key in data})
92
+ elif isinstance(data, collections.abc.Sequence) and not isinstance(data, (str, bytes)):
93
+ return type(data)([serialize(elem) for elem in data])
94
+ else:
95
+ try:
96
+ json.dumps(data)
97
+ except TypeError:
98
+ data = str(data)
99
+ return data
100
+
101
+
102
+ @staticmethod
103
+ def set_random_seed(seed: int, by_rank: bool = False) -> None:
104
+ """Set random seed. This includes random, numpy, Pytorch.
105
+
106
+ Args:
107
+ seed (int): Random seed.
108
+ by_rank (bool): if true, each GPU will use a different random seed.
109
+ """
110
+ if by_rank:
111
+ seed += get_rank()
112
+ log.info(f"Using random seed {seed}.")
113
+ random.seed(seed)
114
+ np.random.seed(seed)
115
+ torch.manual_seed(seed) # sets seed on the current CPU & all GPUs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
 
 
 
117
 
118
+ @staticmethod
119
+ def arch_invariant_rand(
120
+ shape: List[int] | Tuple[int], dtype: torch.dtype, device: str | torch.device, seed: int | None = None
121
+ ):
122
+ """Produce a GPU-architecture-invariant randomized Torch tensor.
123
+
124
+ Args:
125
+ shape (list or tuple of ints): Output tensor shape.
126
+ dtype (torch.dtype): Output tensor type.
127
+ device (torch.device): Device holding the output.
128
+ seed (int): Optional randomization seed.
129
+
130
+ Returns:
131
+ tensor (torch.tensor): Randomly-generated tensor.
132
+ """
133
+ # Create a random number generator, optionally seeded
134
+ rng = np.random.RandomState(seed)
135
+
136
+ # # Generate random numbers using the generator
137
+ random_array = rng.standard_normal(shape).astype(np.float32) # Use standard_normal for normal distribution
138
+
139
+ # Convert to torch tensor and return
140
+ return torch.from_numpy(random_array).to(dtype=dtype, device=device)
141
 
142
 
143
  T = TypeVar("T", bound=Callable[..., Any])
model_t2w.py CHANGED
@@ -25,7 +25,7 @@ from .res_sampler import COMMON_SOLVER_OPTIONS, Sampler
25
  from .diffusion_types import DenoisePrediction
26
  from .blocks import FourierFeatures
27
  from .pretrained_vae import BaseVAE
28
- from . import misc
29
  from . import instantiate as lazy_instantiate
30
  from .log import log
31
 
@@ -96,7 +96,7 @@ class DiffusionT2WModel(torch.nn.Module):
96
  if hasattr(self.tokenizer, "reset_dtype"):
97
  self.tokenizer.reset_dtype()
98
 
99
- @misc.timer("DiffusionModel: set_up_model")
100
  def set_up_model(self, memory_format: torch.memory_format = torch.preserve_format):
101
  """Initialize the core model components including network, conditioner and logvar."""
102
  self.model = self.build_model()
 
25
  from .diffusion_types import DenoisePrediction
26
  from .blocks import FourierFeatures
27
  from .pretrained_vae import BaseVAE
28
+ from .misc import misc, Color, timer
29
  from . import instantiate as lazy_instantiate
30
  from .log import log
31
 
 
96
  if hasattr(self.tokenizer, "reset_dtype"):
97
  self.tokenizer.reset_dtype()
98
 
99
+ @timer("DiffusionModel: set_up_model")
100
  def set_up_model(self, memory_format: torch.memory_format = torch.preserve_format):
101
  """Initialize the core model components including network, conditioner and logvar."""
102
  self.model = self.build_model()
model_v2w.py CHANGED
@@ -24,7 +24,7 @@ from .conditioner import VideoExtendCondition
24
  from .config_base_conditioner import VideoCondBoolConfig
25
  from .batch_ops import batch_mul
26
  from .model_t2w import DiffusionT2WModel
27
- from . import misc
28
 
29
 
30
  @dataclass
 
24
  from .config_base_conditioner import VideoCondBoolConfig
25
  from .batch_ops import batch_mul
26
  from .model_t2w import DiffusionT2WModel
27
+ from .misc import misc, Color, timer
28
 
29
 
30
  @dataclass
text2world.py CHANGED
@@ -21,7 +21,7 @@ import torch
21
 
22
  from .inference_utils import add_common_arguments, validate_args
23
  from .world_generation_pipeline import DiffusionText2WorldGenerationPipeline
24
- from . import misc
25
  from .utils_io import read_prompts_from_file, save_video
26
 
27
  torch.enable_grad(False)
 
21
 
22
  from .inference_utils import add_common_arguments, validate_args
23
  from .world_generation_pipeline import DiffusionText2WorldGenerationPipeline
24
+ from .misc import misc, Color, timer
25
  from .utils_io import read_prompts_from_file, save_video
26
 
27
  torch.enable_grad(False)
text2world_hf.py CHANGED
@@ -6,7 +6,7 @@ from transformers import PreTrainedModel, PretrainedConfig
6
  from .inference_utils import add_common_arguments, validate_args
7
  from .world_generation_pipeline import DiffusionText2WorldGenerationPipeline
8
  from .log import log
9
- from . import misc
10
  from .utils_io import read_prompts_from_file, save_video
11
 
12
 
 
6
  from .inference_utils import add_common_arguments, validate_args
7
  from .world_generation_pipeline import DiffusionText2WorldGenerationPipeline
8
  from .log import log
9
+ from .misc import misc, Color, timer
10
  from .utils_io import read_prompts_from_file, save_video
11
 
12
 
video2world.py CHANGED
@@ -21,7 +21,7 @@ import torch
21
 
22
  from .inference_utils import add_common_arguments, check_input_frames, validate_args
23
  from .world_generation_pipeline import DiffusionVideo2WorldGenerationPipeline
24
- from . import misc
25
  from .utils_io import read_prompts_from_file, save_video
26
 
27
  torch.enable_grad(False)
 
21
 
22
  from .inference_utils import add_common_arguments, check_input_frames, validate_args
23
  from .world_generation_pipeline import DiffusionVideo2WorldGenerationPipeline
24
+ from .misc import misc, Color, timer
25
  from .utils_io import read_prompts_from_file, save_video
26
 
27
  torch.enable_grad(False)
video_content_safety_filter.py CHANGED
@@ -26,7 +26,7 @@ from .guardrail_core import ContentSafetyGuardrail, GuardrailRunner
26
  from .guardrail_io_utils import get_video_filepaths, read_video
27
  from .video_content_safety_filter_model import ModelConfig, VideoSafetyModel
28
  from .video_content_safety_filter_vision_encoder import SigLIPEncoder
29
- from . import misc
30
 
31
  DEFAULT_CHECKPOINT_DIR = "checkpoints/Cosmos-1.0-Guardrail/video_content_safety_filter"
32
 
@@ -178,7 +178,7 @@ def main(args):
178
  runner = GuardrailRunner(safety_models=[video_filter], generic_safe_msg="Video is safe")
179
 
180
  for filepath in filepaths:
181
- with misc.timer("video content safety filter"):
182
  _ = runner.run_safety_check(filepath)
183
 
184
 
 
26
  from .guardrail_io_utils import get_video_filepaths, read_video
27
  from .video_content_safety_filter_model import ModelConfig, VideoSafetyModel
28
  from .video_content_safety_filter_vision_encoder import SigLIPEncoder
29
+ from .misc import misc, Color, timer
30
 
31
  DEFAULT_CHECKPOINT_DIR = "checkpoints/Cosmos-1.0-Guardrail/video_content_safety_filter"
32
 
 
178
  runner = GuardrailRunner(safety_models=[video_filter], generic_safe_msg="Video is safe")
179
 
180
  for filepath in filepaths:
181
+ with timer("video content safety filter"):
182
  _ = runner.run_safety_check(filepath)
183
 
184