EthanZyh commited on
Commit
765a5bb
·
1 Parent(s): 62c1236

modify presets

Browse files
base_world_generation_pipeline.py CHANGED
@@ -22,7 +22,7 @@ import numpy as np
22
  import torch
23
 
24
  from .t5_text_encoder import CosmosT5TextEncoder
25
- from . import presets as guardrail_presets
26
 
27
 
28
  class BaseWorldGenerationPipeline(ABC):
 
22
  import torch
23
 
24
  from .t5_text_encoder import CosmosT5TextEncoder
25
+ from .presets import presets as guardrail_presets
26
 
27
 
28
  class BaseWorldGenerationPipeline(ABC):
presets.py CHANGED
@@ -25,53 +25,59 @@ from .video_content_safety_filter import VideoContentSafetyFilter
25
  from .log import log
26
 
27
 
28
- def create_text_guardrail_runner(checkpoint_dir: str) -> GuardrailRunner:
29
- """Create the text guardrail runner."""
30
- blocklist_checkpoint_dir = os.path.join(checkpoint_dir, "blocklist")
31
- aegis_checkpoint_dir = os.path.join(checkpoint_dir, "aegis")
32
- return GuardrailRunner(safety_models=[Blocklist(blocklist_checkpoint_dir), Aegis(aegis_checkpoint_dir)])
33
-
34
-
35
- def create_video_guardrail_runner(checkpoint_dir: str) -> GuardrailRunner:
36
- """Create the video guardrail runner."""
37
- video_filter_checkpoint_dir = os.path.join(checkpoint_dir, "video_content_safety_filter")
38
- retinaface_checkpoint_path = os.path.join(checkpoint_dir, "face_blur_filter/Resnet50_Final.pth")
39
- return GuardrailRunner(
40
- safety_models=[VideoContentSafetyFilter(video_filter_checkpoint_dir)],
41
- postprocessors=[RetinaFaceFilter(retinaface_checkpoint_path)],
42
- )
43
-
44
-
45
- def run_text_guardrail(prompt: str, guardrail_runner: GuardrailRunner) -> bool:
46
- """Run the text guardrail on the prompt, checking for content safety.
47
-
48
- Args:
49
- prompt: The text prompt.
50
- guardrail_runner: The text guardrail runner.
51
-
52
- Returns:
53
- bool: Whether the prompt is safe.
54
- """
55
- is_safe, message = guardrail_runner.run_safety_check(prompt)
56
- if not is_safe:
57
- log.critical(f"GUARDRAIL BLOCKED: {message}")
58
- return is_safe
59
-
60
-
61
- def run_video_guardrail(frames: np.ndarray, guardrail_runner: GuardrailRunner) -> np.ndarray | None:
62
- """Run the video guardrail on the frames, checking for content safety and applying face blur.
63
-
64
- Args:
65
- frames: The frames of the generated video.
66
- guardrail_runner: The video guardrail runner.
67
-
68
- Returns:
69
- The processed frames if safe, otherwise None.
70
- """
71
- is_safe, message = guardrail_runner.run_safety_check(frames)
72
- if not is_safe:
73
- log.critical(f"GUARDRAIL BLOCKED: {message}")
74
- return None
75
-
76
- frames = guardrail_runner.postprocess(frames)
77
- return frames
 
 
 
 
 
 
 
25
  from .log import log
26
 
27
 
28
+ class presets():
29
+
30
+ @staticmethod
31
+ def create_text_guardrail_runner(checkpoint_dir: str) -> GuardrailRunner:
32
+ """Create the text guardrail runner."""
33
+ blocklist_checkpoint_dir = os.path.join(checkpoint_dir, "blocklist")
34
+ aegis_checkpoint_dir = os.path.join(checkpoint_dir, "aegis")
35
+ return GuardrailRunner(safety_models=[Blocklist(blocklist_checkpoint_dir), Aegis(aegis_checkpoint_dir)])
36
+
37
+
38
+ @staticmethod
39
+ def create_video_guardrail_runner(checkpoint_dir: str) -> GuardrailRunner:
40
+ """Create the video guardrail runner."""
41
+ video_filter_checkpoint_dir = os.path.join(checkpoint_dir, "video_content_safety_filter")
42
+ retinaface_checkpoint_path = os.path.join(checkpoint_dir, "face_blur_filter/Resnet50_Final.pth")
43
+ return GuardrailRunner(
44
+ safety_models=[VideoContentSafetyFilter(video_filter_checkpoint_dir)],
45
+ postprocessors=[RetinaFaceFilter(retinaface_checkpoint_path)],
46
+ )
47
+
48
+
49
+ @staticmethod
50
+ def run_text_guardrail(prompt: str, guardrail_runner: GuardrailRunner) -> bool:
51
+ """Run the text guardrail on the prompt, checking for content safety.
52
+
53
+ Args:
54
+ prompt: The text prompt.
55
+ guardrail_runner: The text guardrail runner.
56
+
57
+ Returns:
58
+ bool: Whether the prompt is safe.
59
+ """
60
+ is_safe, message = guardrail_runner.run_safety_check(prompt)
61
+ if not is_safe:
62
+ log.critical(f"GUARDRAIL BLOCKED: {message}")
63
+ return is_safe
64
+
65
+
66
+ @staticmethod
67
+ def run_video_guardrail(frames: np.ndarray, guardrail_runner: GuardrailRunner) -> np.ndarray | None:
68
+ """Run the video guardrail on the frames, checking for content safety and applying face blur.
69
+
70
+ Args:
71
+ frames: The frames of the generated video.
72
+ guardrail_runner: The video guardrail runner.
73
+
74
+ Returns:
75
+ The processed frames if safe, otherwise None.
76
+ """
77
+ is_safe, message = guardrail_runner.run_safety_check(frames)
78
+ if not is_safe:
79
+ log.critical(f"GUARDRAIL BLOCKED: {message}")
80
+ return None
81
+
82
+ frames = guardrail_runner.postprocess(frames)
83
+ return frames
text2world_prompt_upsampler_inference.py CHANGED
@@ -26,7 +26,7 @@ import re
26
  from .model_config import create_text_model_config
27
  from .ar_model import AutoRegressiveModel
28
  from .inference import chat_completion
29
- from . import presets as guardrail_presets
30
  from .log import log
31
 
32
 
 
26
  from .model_config import create_text_model_config
27
  from .ar_model import AutoRegressiveModel
28
  from .inference import chat_completion
29
+ from .presets import presets as guardrail_presets
30
  from .log import log
31
 
32
 
video2world_prompt_upsampler_inference.py CHANGED
@@ -29,7 +29,7 @@ from PIL import Image
29
  from .model_config import create_vision_language_model_config
30
  from .ar_model import AutoRegressiveModel
31
  from .inference import chat_completion
32
- from . import presets as guardrail_presets
33
  from .log import log
34
  from .utils_io import load_from_fileobj
35
 
 
29
  from .model_config import create_vision_language_model_config
30
  from .ar_model import AutoRegressiveModel
31
  from .inference import chat_completion
32
+ from .presets import presets as guardrail_presets
33
  from .log import log
34
  from .utils_io import load_from_fileobj
35