EthanZyh commited on
Commit
812a75c
·
1 Parent(s): f29f716

try add . in import

Browse files
Files changed (1) hide show
  1. text2world_hf.py +27 -7
text2world_hf.py CHANGED
@@ -3,11 +3,11 @@ import argparse
3
  import torch
4
  from transformers import PreTrainedModel, PretrainedConfig
5
 
6
- from cosmos1.models.diffusion.inference.inference_utils import add_common_arguments, validate_args
7
- from cosmos1.models.diffusion.inference.world_generation_pipeline import DiffusionText2WorldGenerationPipeline
8
- import cosmos1.utils.log as log
9
- import cosmos1.utils.misc as misc
10
- from cosmos1.utils.io import read_prompts_from_file, save_video
11
 
12
  class DiffusionText2WorldConfig(PretrainedConfig):
13
  model_type = "DiffusionText2World"
@@ -46,8 +46,28 @@ class DiffusionText2World(PreTrainedModel):
46
  torch.enable_grad(False) # TODO: do we need this?
47
  self.config = config
48
  inference_type = "text2world"
49
- validate_args(argparse.Namespace(**config), inference_type)
50
- self.pipeline = DiffusionText2WorldGenerationPipeline(config)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
  def forward(self, prompt):
53
  cfg = self.config
 
3
  import torch
4
  from transformers import PreTrainedModel, PretrainedConfig
5
 
6
+ from .cosmos1.models.diffusion.inference.inference_utils import add_common_arguments, validate_args
7
+ from .cosmos1.models.diffusion.inference.world_generation_pipeline import DiffusionText2WorldGenerationPipeline
8
+ import .cosmos1.utils.log as log
9
+ import .cosmos1.utils.misc as misc
10
+ from .cosmos1.utils.io import read_prompts_from_file, save_video
11
 
12
  class DiffusionText2WorldConfig(PretrainedConfig):
13
  model_type = "DiffusionText2World"
 
46
  torch.enable_grad(False) # TODO: do we need this?
47
  self.config = config
48
  inference_type = "text2world"
49
+ config.prompt = 1 # TODO: this is to hack args validation, maybe find a better way
50
+ validate_args(config, inference_type)
51
+ del config.prompt
52
+ self.pipeline = DiffusionText2WorldGenerationPipeline(
53
+ inference_type=inference_type,
54
+ checkpoint_dir=config.checkpoint_dir,
55
+ checkpoint_name=config.diffusion_transformer_dir,
56
+ prompt_upsampler_dir=config.prompt_upsampler_dir,
57
+ enable_prompt_upsampler=not config.disable_prompt_upsampler,
58
+ offload_network=config.offload_diffusion_transformer,
59
+ offload_tokenizer=config.offload_tokenizer,
60
+ offload_text_encoder_model=config.offload_text_encoder_model,
61
+ offload_prompt_upsampler=config.offload_prompt_upsampler,
62
+ offload_guardrail_models=config.offload_guardrail_models,
63
+ guidance=config.guidance,
64
+ num_steps=config.num_steps,
65
+ height=config.height,
66
+ width=config.width,
67
+ fps=config.fps,
68
+ num_video_frames=config.num_video_frames,
69
+ seed=config.seed,
70
+ )
71
 
72
  def forward(self, prompt):
73
  cfg = self.config