"""Hugging Face wrapper exposing the Cosmos-1.0 Diffusion Text2World pipeline as a PreTrainedModel."""

import os

import torch
from transformers import PreTrainedModel, PretrainedConfig

from .cosmos1.models.diffusion.inference.inference_utils import validate_args
from .cosmos1.models.diffusion.inference.world_generation_pipeline import DiffusionText2WorldGenerationPipeline
from .cosmos1.utils import log
from .cosmos1.utils.io import read_prompts_from_file, save_video

class DiffusionText2WorldConfig(PretrainedConfig):
    """Configuration for the Cosmos-1.0 Diffusion Text2World generation pipeline."""

    model_type = "DiffusionText2World"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Checkpoint locations for the diffusion transformer and prompt upsampler
        self.diffusion_transformer_dir = kwargs.get("diffusion_transformer_dir", "Cosmos-1.0-Diffusion-7B-Text2World")
        self.prompt_upsampler_dir = kwargs.get("prompt_upsampler_dir", "Cosmos-1.0-Prompt-Upsampler-12B-Text2World")
        # Prompts longer than this many words skip the upsampler
        self.word_limit_to_skip_upsampler = kwargs.get("word_limit_to_skip_upsampler", 250)
        self.checkpoint_dir = kwargs.get("checkpoint_dir", "checkpoints")
        self.tokenizer_dir = kwargs.get("tokenizer_dir", "Cosmos-1.0-Tokenizer-CV8x8x8")
        # Output naming and location
        self.video_save_name = kwargs.get("video_save_name", "output")
        self.video_save_folder = kwargs.get("video_save_folder", "outputs/")
        # Prompt inputs: a single prompt or a batch file of prompts
        self.prompt = kwargs.get("prompt", None)
        self.batch_input_path = kwargs.get("batch_input_path", None)
        self.negative_prompt = kwargs.get("negative_prompt", None)
        # Sampling and video parameters
        self.num_steps = kwargs.get("num_steps", 35)
        self.guidance = kwargs.get("guidance", 7)
        self.num_video_frames = kwargs.get("num_video_frames", 121)
        self.height = kwargs.get("height", 704)
        self.width = kwargs.get("width", 1280)
        self.fps = kwargs.get("fps", 24)
        self.seed = kwargs.get("seed", 1)
        # Prompt-upsampler toggle and memory-saving offload switches per pipeline stage
        self.disable_prompt_upsampler = kwargs.get("disable_prompt_upsampler", False)
        self.offload_diffusion_transformer = kwargs.get("offload_diffusion_transformer", False)
        self.offload_tokenizer = kwargs.get("offload_tokenizer", False)
        self.offload_text_encoder_model = kwargs.get("offload_text_encoder_model", False)
        self.offload_prompt_upsampler = kwargs.get("offload_prompt_upsampler", False)
        self.offload_guardrail_models = kwargs.get("offload_guardrail_models", False)
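
# A minimal sketch of overriding config defaults (hypothetical values; any key
# not passed keeps the default defined above):
#
#   config = DiffusionText2WorldConfig(num_steps=50, guidance=9, seed=42)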


class DiffusionText2World(PreTrainedModel):
    """PreTrainedModel wrapper that runs Cosmos text2world generation in its forward pass."""

    config_class = DiffusionText2WorldConfig

    def __init__(self, config=None):
        # Avoid a mutable default argument: build a fresh config per instance.
        if config is None:
            config = DiffusionText2WorldConfig()
        super().__init__(config)
        torch.set_grad_enabled(False)  # inference-only wrapper; gradients are never needed
        self.config = config
        inference_type = "text2world"
        # validate_args() expects a truthy prompt, so substitute a placeholder when
        # none is configured and restore the original value afterwards.
        original_prompt = config.prompt
        config.prompt = original_prompt or "placeholder"
        validate_args(config, inference_type)
        config.prompt = original_prompt
        self.pipeline = DiffusionText2WorldGenerationPipeline(
            inference_type=inference_type,
            checkpoint_dir=config.checkpoint_dir,
            checkpoint_name=config.diffusion_transformer_dir,
            prompt_upsampler_dir=config.prompt_upsampler_dir,
            enable_prompt_upsampler=not config.disable_prompt_upsampler,
            offload_network=config.offload_diffusion_transformer,
            offload_tokenizer=config.offload_tokenizer,
            offload_text_encoder_model=config.offload_text_encoder_model,
            offload_prompt_upsampler=config.offload_prompt_upsampler,
            offload_guardrail_models=config.offload_guardrail_models,
            guidance=config.guidance,
            num_steps=config.num_steps,
            height=config.height,
            width=config.width,
            fps=config.fps,
            num_video_frames=config.num_video_frames,
            seed=config.seed,
        )

    def forward(self, prompt=None):
        cfg = self.config
        # Handle multiple prompts if a batch input file is provided
        if cfg.batch_input_path:
            log.info(f"Reading batch inputs from path: {cfg.batch_input_path}")
            prompts = read_prompts_from_file(cfg.batch_input_path)
        else:
            # Single-prompt case: prefer the call argument, fall back to the config
            prompts = [{"prompt": prompt if prompt is not None else cfg.prompt}]

        os.makedirs(cfg.video_save_folder, exist_ok=True)
        for i, input_dict in enumerate(prompts):
            current_prompt = input_dict.get("prompt", None)
            if current_prompt is None:
                log.critical("Prompt is missing, skipping world generation.")
                continue

            # Generate video
            generated_output = self.pipeline.generate(current_prompt, cfg.negative_prompt, cfg.word_limit_to_skip_upsampler)
            if generated_output is None:
                log.critical("Guardrail blocked text2world generation.")
                continue
            video, final_prompt = generated_output  # the returned prompt reflects any upsampling

            if cfg.batch_input_path:
                video_save_path = os.path.join(cfg.video_save_folder, f"{i}.mp4")
                prompt_save_path = os.path.join(cfg.video_save_folder, f"{i}.txt")
            else:
                video_save_path = os.path.join(cfg.video_save_folder, f"{cfg.video_save_name}.mp4")
                prompt_save_path = os.path.join(cfg.video_save_folder, f"{cfg.video_save_name}.txt")

            # Save video
            save_video(
                video=video,
                fps=cfg.fps,
                H=cfg.height,
                W=cfg.width,
                video_save_quality=5,
                video_save_path=video_save_path,
            )

            # Save the final prompt to a text file alongside the video
            with open(prompt_save_path, "w", encoding="utf-8") as f:
                f.write(final_prompt)

            log.info(f"Saved video to {video_save_path}")
            log.info(f"Saved prompt to {prompt_save_path}")
    
    def save_pretrained(self, save_directory, **kwargs):
        # This wrapper holds no weights of its own; the pipeline loads checkpoints
        # from config.checkpoint_dir, so there is nothing to serialize.
        pass

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        # No weights are downloaded or loaded here; construct the model from the
        # supplied config, or a default one when none is given.
        config = kwargs.get("config", DiffusionText2WorldConfig())
        return cls(config)
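

if __name__ == "__main__":
    # Minimal usage sketch (an illustration, not part of the original module):
    # generate a single world video from a text prompt. Assumes the Cosmos-1.0
    # checkpoints are already downloaded under config.checkpoint_dir.
    config = DiffusionText2WorldConfig(prompt="A robotic arm picks up a red cube from a cluttered workbench.")
    model = DiffusionText2World(config=config)
    model(prompt=config.prompt)  # writes output.mp4 and output.txt into outputs/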