roll-ai commited on
Commit
1981bb4
·
verified ·
1 Parent(s): 76b7d7b

Update inference/flovd_demo.py

Browse files
Files changed (1) hide show
  1. inference/flovd_demo.py +77 -34
inference/flovd_demo.py CHANGED
@@ -92,64 +92,107 @@ RESOLUTION_MAP = {
92
  "cogvideox-2b": (480, 720),
93
  }
94
 
95
-
96
-
97
- def load_cogvideox_flovd_FVSM_controlnet_pipeline(controlnet_path, backbone_path, device, dtype):
98
- controlnet_sd = torch.load(controlnet_path, map_location='cpu')['module']
99
-
100
  tokenizer = AutoTokenizer.from_pretrained(backbone_path, subfolder="tokenizer")
101
- text_encoder = T5EncoderModel.from_pretrained(backbone_path, subfolder="text_encoder", device_map="auto", torch_dtype=dtype)
102
- transformer = CogVideoXTransformer3DModel.from_pretrained(backbone_path, subfolder="transformer", device_map="auto", torch_dtype=dtype)
103
- vae = AutoencoderKLCogVideoX.from_pretrained(backbone_path, subfolder="vae", device_map="auto", torch_dtype=dtype)
 
 
 
 
104
  scheduler = CogVideoXDPMScheduler.from_pretrained(backbone_path, subfolder="scheduler")
105
 
106
- controlnet = CogVideoXControlnet.from_pretrained(backbone_path, subfolder="transformer", num_layers=6, out_proj_dim_factor=64, out_proj_dim_zero_init=True, notextinflow=True)
107
- controlnet.eval()
108
- controlnet.load_state_dict(controlnet_sd)
109
 
110
- pipe = FloVDCogVideoXControlnetImageToVideoPipeline(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
  tokenizer=tokenizer,
112
  text_encoder=text_encoder,
113
  vae=vae,
114
  transformer=transformer,
115
- controlnet=controlnet,
116
  scheduler=scheduler,
117
- ).to(device, dtype)
 
 
 
118
 
119
  return pipe
120
 
121
 
122
- def load_cogvideox_flovd_OMSM_lora_pipeline(omsm_path, backbone_path, transformer_lora_config, device, dtype):
 
 
123
  tokenizer = AutoTokenizer.from_pretrained(backbone_path, subfolder="tokenizer")
124
- text_encoder = T5EncoderModel.from_pretrained(backbone_path, subfolder="text_encoder", device_map="auto", torch_dtype=dtype)
125
- transformer = CogVideoXTransformer3DModel.from_pretrained(backbone_path, subfolder="transformer", device_map="auto", torch_dtype=dtype)
126
- vae = AutoencoderKLCogVideoX.from_pretrained(backbone_path, subfolder="vae", device_map="auto", torch_dtype=dtype)
 
 
 
 
127
  scheduler = CogVideoXDPMScheduler.from_pretrained(backbone_path, subfolder="scheduler")
128
 
129
- transformer.add_adapter(transformer_lora_config)
130
- lora_state_dict = FloVDOMSMCogVideoXImageToVideoPipeline.lora_state_dict(omsm_path)
131
- transformer_state_dict = {
132
- f'{k.replace("transformer.", "")}': v
133
- for k, v in lora_state_dict.items()
134
- if k.startswith("transformer.")
135
  }
136
- set_peft_model_state_dict(transformer, transformer_state_dict, adapter_name="default")
 
137
 
138
- other_block_path = os.path.join(omsm_path, "selected_blocks.safetensors")
139
- if os.path.exists(other_block_path):
140
- tensor_dict = load_file(other_block_path)
141
- for block_name, param in tensor_dict.items():
142
- module_name, param_name = block_name.split(".", 1)
143
- if hasattr(transformer, module_name):
144
- getattr(transformer, module_name).load_state_dict({param_name: param}, strict=False)
145
 
146
- pipe = FloVDOMSMCogVideoXImageToVideoPipeline(
 
 
 
 
147
  tokenizer=tokenizer,
148
  text_encoder=text_encoder,
149
  vae=vae,
150
  transformer=transformer,
 
151
  scheduler=scheduler,
152
- ).to(device, dtype)
 
 
 
153
 
154
  return pipe
155
 
 
92
  "cogvideox-2b": (480, 720),
93
  }
94
 
95
+ def load_cogvideox_flovd_OMSM_lora_pipeline(omsm_path, backbone_path, transformer_lora_config, device, dtype):
 
 
 
 
96
  tokenizer = AutoTokenizer.from_pretrained(backbone_path, subfolder="tokenizer")
97
+ text_encoder = T5EncoderModel.from_pretrained(backbone_path, subfolder="text_encoder")
98
+ transformer = CogVideoXTransformer3DModel.from_pretrained(
99
+ backbone_path, subfolder="transformer", torch_dtype=dtype, device_map="auto"
100
+ )
101
+ vae = AutoencoderKLCogVideoX.from_pretrained(
102
+ backbone_path, subfolder="vae", torch_dtype=dtype, device_map="auto"
103
+ )
104
  scheduler = CogVideoXDPMScheduler.from_pretrained(backbone_path, subfolder="scheduler")
105
 
106
+ # 1) Load Lora weight
107
+ transformer.add_adapter(transformer_lora_config)
 
108
 
109
+ lora_state_dict = FloVDOMSMCogVideoXImageToVideoPipeline.lora_state_dict(omsm_path)
110
+ transformer_state_dict = {
111
+ f'{k.replace("transformer.", "")}': v
112
+ for k, v in lora_state_dict.items()
113
+ if k.startswith("transformer.")
114
+ }
115
+ incompatible_keys = set_peft_model_state_dict(transformer, transformer_state_dict, adapter_name="default")
116
+ if incompatible_keys is not None:
117
+ # check only for unexpected keys
118
+ unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
119
+ if unexpected_keys:
120
+ logger.warning(
121
+ f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
122
+ f" {unexpected_keys}. "
123
+ )
124
+
125
+ # 2) Load Other weight
126
+ load_path = os.path.join(omsm_path, "selected_blocks.safetensors")
127
+ if os.path.exists(load_path):
128
+ tensor_dict = load_file(load_path)
129
+
130
+ block_state_dicts = {}
131
+ for k, v in tensor_dict.items():
132
+ block_name, param_name = k.split(".", 1)
133
+ if block_name not in block_state_dicts:
134
+ block_state_dicts[block_name] = {}
135
+ block_state_dicts[block_name][param_name] = v
136
+
137
+ for block_name, state_dict in block_state_dicts.items():
138
+ if hasattr(transformer, block_name):
139
+ getattr(transformer, block_name).load_state_dict(state_dict)
140
+ else:
141
+ raise ValueError(f"Transformer has no attribute '{block_name}'")
142
+
143
+ pipe = FloVDOMSMCogVideoXImageToVideoPipeline(
144
  tokenizer=tokenizer,
145
  text_encoder=text_encoder,
146
  vae=vae,
147
  transformer=transformer,
 
148
  scheduler=scheduler,
149
+ )
150
+
151
+ pipe.vae.enable_slicing()
152
+ pipe.vae.enable_tiling()
153
 
154
  return pipe
155
 
156
 
157
+ def load_cogvideox_flovd_FVSM_controlnet_pipeline(controlnet_path, backbone_path, device, dtype):
158
+ controlnet_sd = torch.load(controlnet_path, map_location='cpu')['module']
159
+
160
  tokenizer = AutoTokenizer.from_pretrained(backbone_path, subfolder="tokenizer")
161
+ text_encoder = T5EncoderModel.from_pretrained(backbone_path, subfolder="text_encoder")
162
+ transformer = CustomCogVideoXTransformer3DModel.from_pretrained(
163
+ backbone_path, subfolder="transformer", torch_dtype=dtype, device_map="auto"
164
+ )
165
+ vae = AutoencoderKLCogVideoX.from_pretrained(
166
+ backbone_path, subfolder="vae", torch_dtype=dtype, device_map="auto"
167
+ )
168
  scheduler = CogVideoXDPMScheduler.from_pretrained(backbone_path, subfolder="scheduler")
169
 
170
+ additional_kwargs = {
171
+ 'num_layers': 6,
172
+ 'out_proj_dim_factor': 64,
173
+ 'out_proj_dim_zero_init': True,
174
+ 'notextinflow': True,
 
175
  }
176
+ controlnet = CogVideoXControlnet.from_pretrained(backbone_path, subfolder="transformer", **additional_kwargs)
177
+ controlnet.eval()
178
 
179
+ missing, unexpected = controlnet.load_state_dict(controlnet_sd)
 
 
 
 
 
 
180
 
181
+ if len(missing) != 0 or len(unexpected) != 0:
182
+ print(f"Missing keys : {missing}")
183
+ print(f"Unexpected keys : {unexpected}")
184
+
185
+ pipe = FloVDCogVideoXControlnetImageToVideoPipeline(
186
  tokenizer=tokenizer,
187
  text_encoder=text_encoder,
188
  vae=vae,
189
  transformer=transformer,
190
+ controlnet=controlnet,
191
  scheduler=scheduler,
192
+ )
193
+
194
+ pipe.vae.enable_slicing()
195
+ pipe.vae.enable_tiling()
196
 
197
  return pipe
198