try add . in import
Browse files- text2world_hf.py +27 -7
text2world_hf.py
CHANGED
@@ -3,11 +3,11 @@ import argparse
|
|
3 |
import torch
|
4 |
from transformers import PreTrainedModel, PretrainedConfig
|
5 |
|
6 |
-
from cosmos1.models.diffusion.inference.inference_utils import add_common_arguments, validate_args
|
7 |
-
from cosmos1.models.diffusion.inference.world_generation_pipeline import DiffusionText2WorldGenerationPipeline
|
8 |
-
import cosmos1.utils.log as log
|
9 |
-
import cosmos1.utils.misc as misc
|
10 |
-
from cosmos1.utils.io import read_prompts_from_file, save_video
|
11 |
|
12 |
class DiffusionText2WorldConfig(PretrainedConfig):
|
13 |
model_type = "DiffusionText2World"
|
@@ -46,8 +46,28 @@ class DiffusionText2World(PreTrainedModel):
|
|
46 |
torch.enable_grad(False) # TODO: do we need this?
|
47 |
self.config = config
|
48 |
inference_type = "text2world"
|
49 |
-
|
50 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
51 |
|
52 |
def forward(self, prompt):
|
53 |
cfg = self.config
|
|
|
3 |
import torch
|
4 |
from transformers import PreTrainedModel, PretrainedConfig
|
5 |
|
6 |
+
from .cosmos1.models.diffusion.inference.inference_utils import add_common_arguments, validate_args
|
7 |
+
from .cosmos1.models.diffusion.inference.world_generation_pipeline import DiffusionText2WorldGenerationPipeline
|
8 |
+
import .cosmos1.utils.log as log
|
9 |
+
import .cosmos1.utils.misc as misc
|
10 |
+
from .cosmos1.utils.io import read_prompts_from_file, save_video
|
11 |
|
12 |
class DiffusionText2WorldConfig(PretrainedConfig):
|
13 |
model_type = "DiffusionText2World"
|
|
|
46 |
torch.enable_grad(False) # TODO: do we need this?
|
47 |
self.config = config
|
48 |
inference_type = "text2world"
|
49 |
+
config.prompt = 1 # TODO: this is to hack args validation, maybe find a better way
|
50 |
+
validate_args(config, inference_type)
|
51 |
+
del config.prompt
|
52 |
+
self.pipeline = DiffusionText2WorldGenerationPipeline(
|
53 |
+
inference_type=inference_type,
|
54 |
+
checkpoint_dir=config.checkpoint_dir,
|
55 |
+
checkpoint_name=config.diffusion_transformer_dir,
|
56 |
+
prompt_upsampler_dir=config.prompt_upsampler_dir,
|
57 |
+
enable_prompt_upsampler=not config.disable_prompt_upsampler,
|
58 |
+
offload_network=config.offload_diffusion_transformer,
|
59 |
+
offload_tokenizer=config.offload_tokenizer,
|
60 |
+
offload_text_encoder_model=config.offload_text_encoder_model,
|
61 |
+
offload_prompt_upsampler=config.offload_prompt_upsampler,
|
62 |
+
offload_guardrail_models=config.offload_guardrail_models,
|
63 |
+
guidance=config.guidance,
|
64 |
+
num_steps=config.num_steps,
|
65 |
+
height=config.height,
|
66 |
+
width=config.width,
|
67 |
+
fps=config.fps,
|
68 |
+
num_video_frames=config.num_video_frames,
|
69 |
+
seed=config.seed,
|
70 |
+
)
|
71 |
|
72 |
def forward(self, prompt):
|
73 |
cfg = self.config
|