File size: 6,085 Bytes
85c9b8b f9c7c59 85c9b8b 812a75c 85c9b8b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
import os
import argparse
import torch
from transformers import PreTrainedModel, PretrainedConfig
from .testimport1 import testA, testB
# from .cosmos1.models.diffusion.inference.inference_utils import add_common_arguments, validate_args
# from .cosmos1.models.diffusion.inference.world_generation_pipeline import DiffusionText2WorldGenerationPipeline
# import .cosmos1.utils.log as log
# import .cosmos1.utils.misc as misc
# from .cosmos1.utils.io import read_prompts_from_file, save_video
class DiffusionText2WorldConfig(PretrainedConfig):
    """Configuration for the ``DiffusionText2World`` model wrapper.

    Every option listed in ``__init__`` is read from ``kwargs`` and falls back
    to the paired default when the key is absent. ``kwargs`` is also forwarded
    unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "DiffusionText2World"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # (attribute name, default) pairs, assigned in this order.
        option_defaults = (
            # Checkpoint locations and prompt-upsampler settings.
            ("diffusion_transformer_dir", "Cosmos-1.0-Diffusion-7B-Text2World"),
            ("prompt_upsampler_dir", "Cosmos-1.0-Prompt-Upsampler-12B-Text2World"),
            ("word_limit_to_skip_upsampler", 250),
            ("checkpoint_dir", "checkpoints"),
            ("tokenizer_dir", "Cosmos-1.0-Tokenizer-CV8x8x8"),
            # Output naming and prompt sources.
            ("video_save_name", "output"),
            ("video_save_folder", "outputs/"),
            ("prompt", None),
            ("batch_input_path", None),
            ("negative_prompt", None),
            # Sampling and video-shape settings.
            ("num_steps", 35),
            ("guidance", 7),
            ("num_video_frames", 121),
            ("height", 704),
            ("width", 1280),
            ("fps", 24),
            ("seed", 1),
            # Upsampler toggle and offload_* switches.
            ("disable_prompt_upsampler", False),
            ("offload_diffusion_transformer", False),
            ("offload_tokenizer", False),
            ("offload_text_encoder_model", False),
            ("offload_prompt_upsampler", False),
            ("offload_guardrail_models", False),
        )
        for option, fallback in option_defaults:
            setattr(self, option, kwargs.get(option, fallback))
class DiffusionText2World(PreTrainedModel):
    """``PreTrainedModel`` wrapper around the Cosmos text-to-world pipeline.

    Construction validates the configuration and builds a
    ``DiffusionText2WorldGenerationPipeline``. :meth:`forward` generates one
    video per prompt and writes each video and its prompt text into
    ``config.video_save_folder``.

    NOTE(review): ``validate_args``, ``DiffusionText2WorldGenerationPipeline``,
    ``log``, ``read_prompts_from_file`` and ``save_video`` are only provided by
    the cosmos1 imports that are commented out at the top of this file. They
    must be imported before this class can run.
    """

    config_class = DiffusionText2WorldConfig

    def __init__(self, config=None):
        """Validate ``config`` and build the generation pipeline.

        Args:
            config: A ``DiffusionText2WorldConfig``. When omitted, a fresh
                default config is created for this instance. A default shared
                between instances would be mutated by the validation step
                below.
        """
        if config is None:
            config = DiffusionText2WorldConfig()
        super().__init__(config)
        self.config = config
        inference_type = "text2world"

        # HACK: prompts are supplied at forward() time, but args validation
        # presumably requires a prompt to be set. Install a placeholder only
        # for the duration of the check, and always restore the caller's
        # value, even if validation raises.
        original_prompt = getattr(config, "prompt", None)
        config.prompt = 1
        try:
            validate_args(config, inference_type)
        finally:
            config.prompt = original_prompt

        self.pipeline = DiffusionText2WorldGenerationPipeline(
            inference_type=inference_type,
            checkpoint_dir=config.checkpoint_dir,
            checkpoint_name=config.diffusion_transformer_dir,
            prompt_upsampler_dir=config.prompt_upsampler_dir,
            enable_prompt_upsampler=not config.disable_prompt_upsampler,
            offload_network=config.offload_diffusion_transformer,
            offload_tokenizer=config.offload_tokenizer,
            offload_text_encoder_model=config.offload_text_encoder_model,
            offload_prompt_upsampler=config.offload_prompt_upsampler,
            offload_guardrail_models=config.offload_guardrail_models,
            guidance=config.guidance,
            num_steps=config.num_steps,
            height=config.height,
            width=config.width,
            fps=config.fps,
            num_video_frames=config.num_video_frames,
            seed=config.seed,
        )

    # Inference only: disable autograd for the duration of generation.
    # (The previous ``torch.enable_grad(False)`` call did not do this.)
    @torch.no_grad()
    def forward(self, prompt=None):
        """Generate videos and save them, with their prompts, to disk.

        Args:
            prompt: Text prompt for single-prompt generation. Falls back to
                ``config.prompt`` when ``None``. Ignored when
                ``config.batch_input_path`` is set; prompts are then read
                from that file instead.

        Outputs go to ``config.video_save_folder`` with these names:

        - Batch mode: ``<i>.mp4`` / ``<i>.txt``.
        - Single-prompt mode: ``<video_save_name>.mp4`` /
          ``<video_save_name>.txt``.

        Entries with no prompt, or for which the pipeline returns ``None``,
        are logged and skipped.
        """
        cfg = self.config
        batch_mode = bool(cfg.batch_input_path)

        if batch_mode:
            log.info(f"Reading batch inputs from path: {cfg.batch_input_path}")
            prompts = read_prompts_from_file(cfg.batch_input_path)
        else:
            if prompt is None:
                prompt = getattr(cfg, "prompt", None)
            prompts = [{"prompt": prompt}]

        os.makedirs(cfg.video_save_folder, exist_ok=True)

        for i, input_dict in enumerate(prompts):
            current_prompt = input_dict.get("prompt", None)
            if current_prompt is None:
                log.critical("Prompt is missing, skipping world generation.")
                continue

            generated_output = self.pipeline.generate(
                current_prompt, cfg.negative_prompt, cfg.word_limit_to_skip_upsampler
            )
            if generated_output is None:
                log.critical("Guardrail blocked text2world generation.")
                continue
            # Distinct name so the ``prompt`` argument is not shadowed.
            video, final_prompt = generated_output

            stem = i if batch_mode else cfg.video_save_name
            video_save_path = os.path.join(cfg.video_save_folder, f"{stem}.mp4")
            prompt_save_path = os.path.join(cfg.video_save_folder, f"{stem}.txt")

            save_video(
                video=video,
                fps=cfg.fps,
                H=cfg.height,
                W=cfg.width,
                video_save_quality=5,
                video_save_path=video_save_path,
            )
            # Binary mode with explicit UTF-8 means the bytes on disk do not
            # depend on the platform's default encoding or newline translation.
            with open(prompt_save_path, "wb") as f:
                f.write(final_prompt.encode("utf-8"))
            log.info(f"Saved video to {video_save_path}")
            log.info(f"Saved prompt to {prompt_save_path}")

    def save_pretrained(self, save_directory, **kwargs):
        """Intentionally a no-op: this wrapper writes neither weights nor config.

        The pipeline is built from ``config.checkpoint_dir`` at construction
        time, so there is nothing model-specific to serialize here.
        """
        pass

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Build the wrapper from an explicit config or one loaded from a path.

        Args:
            pretrained_model_name_or_path: Used only when no ``config`` kwarg
                is given. The config is then loaded with
                ``config_class.from_pretrained``.
            *model_args: Ignored.
            **kwargs: Only ``config`` is used; all other keys are ignored.

        Returns:
            A new ``DiffusionText2World`` instance.
        """
        config = kwargs.pop("config", None)
        if config is None:
            config = cls.config_class.from_pretrained(pretrained_model_name_or_path)
        return cls(config)