Update inference/flovd_demo.py
inference/flovd_demo.py  +36 -77
@@ -94,104 +94,63 @@ RESOLUTION_MAP = {
 
 
 
-
 def load_cogvideox_flovd_FVSM_controlnet_pipeline(controlnet_path, backbone_path, device, dtype):
     controlnet_sd = torch.load(controlnet_path, map_location='cpu')['module']
-
+
     tokenizer = AutoTokenizer.from_pretrained(backbone_path, subfolder="tokenizer")
-    text_encoder = T5EncoderModel.from_pretrained(backbone_path, subfolder="text_encoder")
-    transformer = CogVideoXTransformer3DModel.from_pretrained(backbone_path, subfolder="transformer")
-    vae = AutoencoderKLCogVideoX.from_pretrained(backbone_path, subfolder="vae")
+    text_encoder = T5EncoderModel.from_pretrained(backbone_path, subfolder="text_encoder", device_map="auto", torch_dtype=dtype)
+    transformer = CogVideoXTransformer3DModel.from_pretrained(backbone_path, subfolder="transformer", device_map="auto", torch_dtype=dtype)
+    vae = AutoencoderKLCogVideoX.from_pretrained(backbone_path, subfolder="vae", device_map="auto", torch_dtype=dtype)
     scheduler = CogVideoXDPMScheduler.from_pretrained(backbone_path, subfolder="scheduler")
-
-    additional_kwargs = {
-        'num_layers': 6,
-        'out_proj_dim_factor': 64,
-        'out_proj_dim_zero_init': True,
-        'notextinflow': True,
-    }
-    controlnet = CogVideoXControlnet.from_pretrained(backbone_path, subfolder="transformer", **additional_kwargs)
+
+    controlnet = CogVideoXControlnet.from_pretrained(backbone_path, subfolder="transformer", num_layers=6, out_proj_dim_factor=64, out_proj_dim_zero_init=True, notextinflow=True)
     controlnet.eval()
-
-    missing, unexpected = controlnet.load_state_dict(controlnet_sd, strict=False)
-
-    if len(missing) != 0 or len(unexpected) != 0:
-        print(f"Missing keys : {missing}")
-        print(f"Unexpected keys : {unexpected}")
-
+    controlnet.load_state_dict(controlnet_sd)
+
     pipe = FloVDCogVideoXControlnetImageToVideoPipeline(
-        tokenizer=tokenizer,
-        text_encoder=text_encoder,
-        vae=vae,
-        transformer=transformer,
-        controlnet=controlnet,
-        scheduler=scheduler,
-    )
-
-    # pipe.enable_model_cpu_offload(device=device)
-    pipe = pipe.to(device, dtype)
-
+        tokenizer=tokenizer,
+        text_encoder=text_encoder,
+        vae=vae,
+        transformer=transformer,
+        controlnet=controlnet,
+        scheduler=scheduler,
+    ).to(device, dtype)
+
     return pipe
 
+
 def load_cogvideox_flovd_OMSM_lora_pipeline(omsm_path, backbone_path, transformer_lora_config, device, dtype):
     tokenizer = AutoTokenizer.from_pretrained(backbone_path, subfolder="tokenizer")
-    text_encoder = T5EncoderModel.from_pretrained(backbone_path, subfolder="text_encoder")
-    transformer = CogVideoXTransformer3DModel.from_pretrained(backbone_path, subfolder="transformer")
-    vae = AutoencoderKLCogVideoX.from_pretrained(backbone_path, subfolder="vae")
+    text_encoder = T5EncoderModel.from_pretrained(backbone_path, subfolder="text_encoder", device_map="auto", torch_dtype=dtype)
+    transformer = CogVideoXTransformer3DModel.from_pretrained(backbone_path, subfolder="transformer", device_map="auto", torch_dtype=dtype)
+    vae = AutoencoderKLCogVideoX.from_pretrained(backbone_path, subfolder="vae", device_map="auto", torch_dtype=dtype)
     scheduler = CogVideoXDPMScheduler.from_pretrained(backbone_path, subfolder="scheduler")
 
-    # 1) Load Lora weight
     transformer.add_adapter(transformer_lora_config)
-
     lora_state_dict = FloVDOMSMCogVideoXImageToVideoPipeline.lora_state_dict(omsm_path)
     transformer_state_dict = {
         f'{k.replace("transformer.", "")}': v
         for k, v in lora_state_dict.items()
         if k.startswith("transformer.")
     }
-    incompatible_keys = set_peft_model_state_dict(transformer, transformer_state_dict, adapter_name="default")
-    if incompatible_keys is not None:
-        # check only for unexpected keys
-        unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
-        if unexpected_keys:
-            logger.warning(
-                f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
-                f" {unexpected_keys}. "
-            )
+    set_peft_model_state_dict(transformer, transformer_state_dict, adapter_name="default")
+
+    other_block_path = os.path.join(omsm_path, "selected_blocks.safetensors")
+    if os.path.exists(other_block_path):
+        tensor_dict = load_file(other_block_path)
+        for block_name, param in tensor_dict.items():
+            module_name, param_name = block_name.split(".", 1)
+            if hasattr(transformer, module_name):
+                getattr(transformer, module_name).load_state_dict({param_name: param}, strict=False)
 
-    # 2) Load Other weight
-    load_path = os.path.join(omsm_path, "selected_blocks.safetensors")
-    if os.path.exists(load_path):
-        tensor_dict = load_file(load_path)
-
-        block_state_dicts = {}
-        for k, v in tensor_dict.items():
-            block_name, param_name = k.split(".", 1)
-            if block_name not in block_state_dicts:
-                block_state_dicts[block_name] = {}
-            block_state_dicts[block_name][param_name] = v
-
-        for block_name, state_dict in block_state_dicts.items():
-            if hasattr(transformer, block_name):
-                getattr(transformer, block_name).load_state_dict(state_dict)
-            else:
-                raise ValueError(f"Transformer has no attribute '{block_name}'")
-
-
     pipe = FloVDOMSMCogVideoXImageToVideoPipeline(
-        tokenizer=tokenizer,
-        text_encoder=text_encoder,
-        vae=vae,
-        transformer=transformer,
-        scheduler=scheduler,
-    )
-
-    # pipe.load_lora_weights(omsm_path, weight_name="pytorch_lora_weights.safetensors", adapter_name="test_1")
-    # pipe.fuse_lora(components=["transformer"], lora_scale=1.0)
-
-    # pipe.enable_model_cpu_offload(device=device)
-    pipe = pipe.to(device, dtype)
-
+        tokenizer=tokenizer,
+        text_encoder=text_encoder,
+        vae=vae,
+        transformer=transformer,
+        scheduler=scheduler,
+    ).to(device, dtype)
+
     return pipe
 
 
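For reference, a minimal sketch of how these two loaders might be invoked after this change. The checkpoint paths, the THUDM/CogVideoX-5b-I2V backbone id, and the LoRA rank/alpha/target modules below are illustrative assumptions, not values taken from this commit:

import torch
from peft import LoraConfig

from inference.flovd_demo import (
    load_cogvideox_flovd_FVSM_controlnet_pipeline,
    load_cogvideox_flovd_OMSM_lora_pipeline,
)

backbone = "THUDM/CogVideoX-5b-I2V"  # assumed CogVideoX image-to-video backbone

# FVSM: the controlnet checkpoint is a raw torch file whose 'module' entry
# holds the state dict, matching the torch.load(...)['module'] call above.
fvsm_pipe = load_cogvideox_flovd_FVSM_controlnet_pipeline(
    controlnet_path="ckpt/FVSM/controlnet.pt",  # hypothetical path
    backbone_path=backbone,
    device="cuda",
    dtype=torch.bfloat16,
)

# OMSM: the LoRA config must match the one the adapter weights were trained
# with; the rank, alpha, and target modules here are placeholders.
lora_config = LoraConfig(
    r=128,
    lora_alpha=64,
    init_lora_weights=True,
    target_modules=["to_q", "to_k", "to_v", "to_out.0"],
)
omsm_pipe = load_cogvideox_flovd_OMSM_lora_pipeline(
    omsm_path="ckpt/OMSM",  # hypothetical dir with LoRA weights and selected_blocks.safetensors
    backbone_path=backbone,
    transformer_lora_config=lora_config,
    device="cuda",
    dtype=torch.bfloat16,
)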